diff --git a/.prettierignore b/.prettierignore index 95097cbec1..d8801ad33a 100644 --- a/.prettierignore +++ b/.prettierignore @@ -14,6 +14,7 @@ packages/amplify-graphql-api-construct-tests/amplify-e2e-reports packages/amplify-graphql-api-construct/README.md packages/amplify-graphql-api-construct/tsconfig.json packages/amplify-graphql-conversation-transformer/src/__tests__/schemas/*.graphql +packages/amplify-graphql-conversation-transformer/src/resolvers/*.template.js packages/amplify-data-construct/README.md packages/amplify-data-construct/tsconfig.json packages/amplify-graphql-model-transformer/publish-notification-lambda/lib/ diff --git a/packages/amplify-data-construct/.jsii b/packages/amplify-data-construct/.jsii index a8f5136727..6a21d85cbc 100644 --- a/packages/amplify-data-construct/.jsii +++ b/packages/amplify-data-construct/.jsii @@ -6,7 +6,7 @@ ] }, "bundled": { - "@aws-amplify/ai-constructs": "^0.1.4", + "@aws-amplify/ai-constructs": "0.0.0-test-20240925144932", "@aws-amplify/backend-output-schemas": "^1.0.0", "@aws-amplify/backend-output-storage": "^1.0.0", "@aws-amplify/graphql-auth-transformer": "4.1.1", @@ -3970,5 +3970,5 @@ }, "types": {}, "version": "1.10.1", - "fingerprint": "18+pcRZF+HEurnEkAAQsQVHj1oGYG+i0hcJuPIeDP7U=" + "fingerprint": "XIFtuREMLP4ld3+h28KKIQLwEL1DxCLKiMqJERSh2wY=" } \ No newline at end of file diff --git a/packages/amplify-data-construct/package.json b/packages/amplify-data-construct/package.json index 4b92a0c453..faa020e0d0 100644 --- a/packages/amplify-data-construct/package.json +++ b/packages/amplify-data-construct/package.json @@ -157,7 +157,7 @@ "semver" ], "dependencies": { - "@aws-amplify/ai-constructs": "^0.1.4", + "@aws-amplify/ai-constructs": "0.0.0-test-20240925144932", "@aws-amplify/backend-output-schemas": "^1.0.0", "@aws-amplify/backend-output-storage": "^1.0.0", "@aws-amplify/graphql-api-construct": "1.13.0", diff --git a/packages/amplify-graphql-api-construct/.jsii b/packages/amplify-graphql-api-construct/.jsii index f7f07ff206..2b4370c17c 100644 --- a/packages/amplify-graphql-api-construct/.jsii +++ b/packages/amplify-graphql-api-construct/.jsii @@ -6,7 +6,7 @@ ] }, "bundled": { - "@aws-amplify/ai-constructs": "^0.1.4", + "@aws-amplify/ai-constructs": "0.0.0-test-20240925144932", "@aws-amplify/backend-output-schemas": "^1.0.0", "@aws-amplify/backend-output-storage": "^1.0.0", "@aws-amplify/graphql-auth-transformer": "4.1.1", @@ -8887,5 +8887,5 @@ } }, "version": "1.13.0", - "fingerprint": "fBEq7TqsVFMCTBRZ+ohz6O/ygsZR+7evKnwJZkkf+/A=" + "fingerprint": "bwGSsIcxucYASo+SO2R0k8vwf8deNlWleoyg6ajeR1s=" } \ No newline at end of file diff --git a/packages/amplify-graphql-api-construct/package.json b/packages/amplify-graphql-api-construct/package.json index b409472c05..25ffe394cc 100644 --- a/packages/amplify-graphql-api-construct/package.json +++ b/packages/amplify-graphql-api-construct/package.json @@ -158,7 +158,7 @@ "semver" ], "dependencies": { - "@aws-amplify/ai-constructs": "^0.1.4", + "@aws-amplify/ai-constructs": "0.0.0-test-20240925144932", "@aws-amplify/backend-output-schemas": "^1.0.0", "@aws-amplify/backend-output-storage": "^1.0.0", "@aws-amplify/graphql-auth-transformer": "4.1.1", diff --git a/packages/amplify-graphql-conversation-transformer/package.json b/packages/amplify-graphql-conversation-transformer/package.json index 68240bd504..22191fcdda 100644 --- a/packages/amplify-graphql-conversation-transformer/package.json +++ b/packages/amplify-graphql-conversation-transformer/package.json @@ -16,14 +16,15 @@ "access": "public" }, "scripts": { - "build": "tsc", + "build": "tsc && yarn copy-js-resolver-templates", "watch": "tsc -w", "clean": "rimraf ./lib", + "copy-js-resolver-templates": "cp ./src/resolvers/*.template.js ./lib/resolvers", "test": "jest", "extract-api": "ts-node ../../scripts/extract-api.ts" }, "dependencies": { - "@aws-amplify/ai-constructs": "^0.1.4", + "@aws-amplify/ai-constructs": "0.0.0-test-20240925144932", "@aws-amplify/graphql-directives": "2.2.0", "@aws-amplify/graphql-index-transformer": "3.0.3", "@aws-amplify/graphql-model-transformer": "3.0.3", diff --git a/packages/amplify-graphql-conversation-transformer/src/__tests__/__snapshots__/amplify-graphql-conversation-transformer.test.ts.snap b/packages/amplify-graphql-conversation-transformer/src/__tests__/__snapshots__/amplify-graphql-conversation-transformer.test.ts.snap index 0b994fbf8c..696843aae8 100644 --- a/packages/amplify-graphql-conversation-transformer/src/__tests__/__snapshots__/amplify-graphql-conversation-transformer.test.ts.snap +++ b/packages/amplify-graphql-conversation-transformer/src/__tests__/__snapshots__/amplify-graphql-conversation-transformer.test.ts.snap @@ -47,13 +47,13 @@ exports[`ConversationTransformer valid schemas should transform conversation rou "Fn::Join": [ "", [ - " import { util } from '@aws-appsync/utils'; + "import { util } from '@aws-appsync/utils'; - export function request(ctx) { - const { args, identity, request, prev } = ctx; - - const selectionSet = 'id conversationId content { image { format source { bytes }} text toolUse { toolUseId name input } toolResult { status toolUseId content { json text image { format source { bytes }} document { format name source { bytes }} }}} role owner createdAt updatedAt'; - const graphqlApiEndpoint = '", +export function request(ctx) { + const { args, request, prev } = ctx; + + const selectionSet = 'id conversationId content { image { format source { bytes }} text toolUse { toolUseId name input } toolResult { status toolUseId content { json text image { format source { bytes }} document { format name source { bytes }} }}} role owner createdAt updatedAt'; + const graphqlApiEndpoint = '", { "Fn::GetAtt": [ "GraphQLAPI", @@ -62,60 +62,71 @@ exports[`ConversationTransformer valid schemas should transform conversation rou }, "'; - const messages = prev.result.items; - const responseMutation = { - name: 'createAssistantResponsePirateChat', - inputTypeName: 'CreateConversationMessagePirateChatAssistantInput', - selectionSet, - }; - const currentMessageId = ctx.stash.defaultValues.id; - const modelConfiguration = { + const messages = prev.result.items; + const responseMutation = { + name: 'createAssistantResponsePirateChat', + inputTypeName: 'CreateConversationMessagePirateChatAssistantInput', + selectionSet, + }; + const currentMessageId = ctx.stash.defaultValues.id; + const modelConfiguration = { modelId: 'anthropic.claude-3-haiku-20240307-v1:0', systemPrompt: "You are a helpful chatbot. Answer questions to the best of your ability.", inferenceConfiguration: {"temperature":0.5,"topP":0.9,"maxTokens":100}, }; - const clientTools = args.toolConfiguration?.tools?.map((tool) => { return { ...tool.toolSpec }}); - const toolsConfiguration = { + const clientTools = args.toolConfiguration?.tools?.map((tool) => { + return { ...tool.toolSpec }; + }); + const toolsConfiguration = { clientTools }; - const authHeader = request.headers['authorization']; - const payload = { - conversationId: args.conversationId, - currentMessageId, - responseMutation, - graphqlApiEndpoint, - modelConfiguration, - request: { headers: { authorization: authHeader }}, - messages, - toolsConfiguration, - }; - - return { - operation: 'Invoke', - payload, - invocationType: 'Event' - }; - } + const messageHistoryQuery = { + getQueryName: 'getConversationMessagePirateChat', + getQueryInputTypeName: 'ID', + listQueryName: 'listConversationMessagePirateChats', + listQueryInputTypeName: 'ModelConversationMessagePirateChatFilterInput', + listQueryLimit: undefined, + }; + + const authHeader = request.headers['authorization']; + const payload = { + conversationId: args.conversationId, + currentMessageId, + responseMutation, + graphqlApiEndpoint, + modelConfiguration, + request: { headers: { authorization: authHeader } }, + messageHistoryQuery, + toolsConfiguration, + }; + + return { + operation: 'Invoke', + payload, + invocationType: 'Event', + }; +} export function response(ctx) { - let success = true; if (ctx.error) { util.appendError(ctx.error.message, ctx.error.type); - success = false; } const response = { - __typename: 'ConversationMessagePirateChat', - id: ctx.stash.defaultValues.id, - conversationId: ctx.args.conversationId, - role: 'user', - content: ctx.args.content, - createdAt: ctx.stash.defaultValues.createdAt, - updatedAt: ctx.stash.defaultValues.updatedAt, + __typename: 'ConversationMessagePirateChat', + id: ctx.stash.defaultValues.id, + conversationId: ctx.args.conversationId, + role: 'user', + content: ctx.args.content, + aiContext: ctx.args.aiContext, + toolConfiguration: ctx.args.toolConfiguration, + createdAt: ctx.stash.defaultValues.createdAt, + updatedAt: ctx.stash.defaultValues.updatedAt, }; return response; -}", +} +", ], ], } @@ -168,13 +179,13 @@ exports[`ConversationTransformer valid schemas should transform conversation rou "Fn::Join": [ "", [ - " import { util } from '@aws-appsync/utils'; + "import { util } from '@aws-appsync/utils'; - export function request(ctx) { - const { args, identity, request, prev } = ctx; - const toolDefinitions = {"tools":[{"name":"listTodos","description":"lists todos","inputSchema":{"json":{"type":"object","properties":{},"required":[]}},"graphqlRequestInputDescriptor":{"selectionSet":"items { content isDone id createdAt updatedAt owner } nextToken","propertyTypes":{},"queryName":"listTodos"}}]}; - const selectionSet = 'id conversationId content { image { format source { bytes }} text toolUse { toolUseId name input } toolResult { status toolUseId content { json text image { format source { bytes }} document { format name source { bytes }} }}} role owner createdAt updatedAt'; - const graphqlApiEndpoint = '", +export function request(ctx) { + const { args, request, prev } = ctx; + const toolDefinitions = {"tools":[{"name":"listTodos","description":"lists todos","inputSchema":{"json":{"type":"object","properties":{},"required":[]}},"graphqlRequestInputDescriptor":{"selectionSet":"items { content isDone id createdAt updatedAt owner } nextToken","propertyTypes":{},"queryName":"listTodos"}}]}; + const selectionSet = 'id conversationId content { image { format source { bytes }} text toolUse { toolUseId name input } toolResult { status toolUseId content { json text image { format source { bytes }} document { format name source { bytes }} }}} role owner createdAt updatedAt'; + const graphqlApiEndpoint = '", { "Fn::GetAtt": [ "GraphQLAPI", @@ -183,62 +194,73 @@ exports[`ConversationTransformer valid schemas should transform conversation rou }, "'; - const messages = prev.result.items; - const responseMutation = { - name: 'createAssistantResponsePirateChat', - inputTypeName: 'CreateConversationMessagePirateChatAssistantInput', - selectionSet, - }; - const currentMessageId = ctx.stash.defaultValues.id; - const modelConfiguration = { + const messages = prev.result.items; + const responseMutation = { + name: 'createAssistantResponsePirateChat', + inputTypeName: 'CreateConversationMessagePirateChatAssistantInput', + selectionSet, + }; + const currentMessageId = ctx.stash.defaultValues.id; + const modelConfiguration = { modelId: 'anthropic.claude-3-haiku-20240307-v1:0', systemPrompt: "You are a helpful chatbot. Answer questions to the best of your ability.", }; - const clientTools = args.toolConfiguration?.tools?.map((tool) => { return { ...tool.toolSpec }}); - const dataTools = toolDefinitions.tools; + const clientTools = args.toolConfiguration?.tools?.map((tool) => { + return { ...tool.toolSpec }; + }); + const dataTools = toolDefinitions.tools; const toolsConfiguration = { dataTools, clientTools, }; - const authHeader = request.headers['authorization']; - const payload = { - conversationId: args.conversationId, - currentMessageId, - responseMutation, - graphqlApiEndpoint, - modelConfiguration, - request: { headers: { authorization: authHeader }}, - messages, - toolsConfiguration, - }; - - return { - operation: 'Invoke', - payload, - invocationType: 'Event' - }; - } + const messageHistoryQuery = { + getQueryName: 'getConversationMessagePirateChat', + getQueryInputTypeName: 'ID', + listQueryName: 'listConversationMessagePirateChats', + listQueryInputTypeName: 'ModelConversationMessagePirateChatFilterInput', + listQueryLimit: undefined, + }; + + const authHeader = request.headers['authorization']; + const payload = { + conversationId: args.conversationId, + currentMessageId, + responseMutation, + graphqlApiEndpoint, + modelConfiguration, + request: { headers: { authorization: authHeader } }, + messageHistoryQuery, + toolsConfiguration, + }; + + return { + operation: 'Invoke', + payload, + invocationType: 'Event', + }; +} export function response(ctx) { - let success = true; if (ctx.error) { util.appendError(ctx.error.message, ctx.error.type); - success = false; } const response = { - __typename: 'ConversationMessagePirateChat', - id: ctx.stash.defaultValues.id, - conversationId: ctx.args.conversationId, - role: 'user', - content: ctx.args.content, - createdAt: ctx.stash.defaultValues.createdAt, - updatedAt: ctx.stash.defaultValues.updatedAt, + __typename: 'ConversationMessagePirateChat', + id: ctx.stash.defaultValues.id, + conversationId: ctx.args.conversationId, + role: 'user', + content: ctx.args.content, + aiContext: ctx.args.aiContext, + toolConfiguration: ctx.args.toolConfiguration, + createdAt: ctx.stash.defaultValues.createdAt, + updatedAt: ctx.stash.defaultValues.updatedAt, }; return response; -}", +} +", ], ], } @@ -291,13 +313,13 @@ exports[`ConversationTransformer valid schemas should transform conversation rou "Fn::Join": [ "", [ - " import { util } from '@aws-appsync/utils'; + "import { util } from '@aws-appsync/utils'; - export function request(ctx) { - const { args, identity, request, prev } = ctx; - const toolDefinitions = {"tools":[{"name":"listCustomers","description":"Provides data about the customer sending a message","inputSchema":{"json":{"type":"object","properties":{},"required":[]}},"graphqlRequestInputDescriptor":{"selectionSet":"items { name email activeCart { products { name price } customerId id createdAt updatedAt owner } orderHistory { items { products { name price } customerId id createdAt updatedAt owner } nextToken } id createdAt updatedAt owner } nextToken","propertyTypes":{},"queryName":"listCustomers"}}]}; - const selectionSet = 'id conversationId content { image { format source { bytes }} text toolUse { toolUseId name input } toolResult { status toolUseId content { json text image { format source { bytes }} document { format name source { bytes }} }}} role owner createdAt updatedAt'; - const graphqlApiEndpoint = '", +export function request(ctx) { + const { args, request, prev } = ctx; + const toolDefinitions = {"tools":[{"name":"listCustomers","description":"Provides data about the customer sending a message","inputSchema":{"json":{"type":"object","properties":{},"required":[]}},"graphqlRequestInputDescriptor":{"selectionSet":"items { name email activeCart { products { name price } customerId id createdAt updatedAt owner } orderHistory { items { products { name price } customerId id createdAt updatedAt owner } nextToken } id createdAt updatedAt owner } nextToken","propertyTypes":{},"queryName":"listCustomers"}}]}; + const selectionSet = 'id conversationId content { image { format source { bytes }} text toolUse { toolUseId name input } toolResult { status toolUseId content { json text image { format source { bytes }} document { format name source { bytes }} }}} role owner createdAt updatedAt'; + const graphqlApiEndpoint = '", { "Fn::GetAtt": [ "GraphQLAPI", @@ -306,62 +328,73 @@ exports[`ConversationTransformer valid schemas should transform conversation rou }, "'; - const messages = prev.result.items; - const responseMutation = { - name: 'createAssistantResponsePirateChat', - inputTypeName: 'CreateConversationMessagePirateChatAssistantInput', - selectionSet, - }; - const currentMessageId = ctx.stash.defaultValues.id; - const modelConfiguration = { + const messages = prev.result.items; + const responseMutation = { + name: 'createAssistantResponsePirateChat', + inputTypeName: 'CreateConversationMessagePirateChatAssistantInput', + selectionSet, + }; + const currentMessageId = ctx.stash.defaultValues.id; + const modelConfiguration = { modelId: 'anthropic.claude-3-haiku-20240307-v1:0', systemPrompt: "You are a helpful chatbot. Answer questions to the best of your ability.", }; - const clientTools = args.toolConfiguration?.tools?.map((tool) => { return { ...tool.toolSpec }}); - const dataTools = toolDefinitions.tools; + const clientTools = args.toolConfiguration?.tools?.map((tool) => { + return { ...tool.toolSpec }; + }); + const dataTools = toolDefinitions.tools; const toolsConfiguration = { dataTools, clientTools, }; - const authHeader = request.headers['authorization']; - const payload = { - conversationId: args.conversationId, - currentMessageId, - responseMutation, - graphqlApiEndpoint, - modelConfiguration, - request: { headers: { authorization: authHeader }}, - messages, - toolsConfiguration, - }; - - return { - operation: 'Invoke', - payload, - invocationType: 'Event' - }; - } + const messageHistoryQuery = { + getQueryName: 'getConversationMessagePirateChat', + getQueryInputTypeName: 'ID', + listQueryName: 'listConversationMessagePirateChats', + listQueryInputTypeName: 'ModelConversationMessagePirateChatFilterInput', + listQueryLimit: undefined, + }; + + const authHeader = request.headers['authorization']; + const payload = { + conversationId: args.conversationId, + currentMessageId, + responseMutation, + graphqlApiEndpoint, + modelConfiguration, + request: { headers: { authorization: authHeader } }, + messageHistoryQuery, + toolsConfiguration, + }; + + return { + operation: 'Invoke', + payload, + invocationType: 'Event', + }; +} export function response(ctx) { - let success = true; if (ctx.error) { util.appendError(ctx.error.message, ctx.error.type); - success = false; } const response = { - __typename: 'ConversationMessagePirateChat', - id: ctx.stash.defaultValues.id, - conversationId: ctx.args.conversationId, - role: 'user', - content: ctx.args.content, - createdAt: ctx.stash.defaultValues.createdAt, - updatedAt: ctx.stash.defaultValues.updatedAt, + __typename: 'ConversationMessagePirateChat', + id: ctx.stash.defaultValues.id, + conversationId: ctx.args.conversationId, + role: 'user', + content: ctx.args.content, + aiContext: ctx.args.aiContext, + toolConfiguration: ctx.args.toolConfiguration, + createdAt: ctx.stash.defaultValues.createdAt, + updatedAt: ctx.stash.defaultValues.updatedAt, }; return response; -}", +} +", ], ], } @@ -414,13 +447,13 @@ exports[`ConversationTransformer valid schemas should transform conversation rou "Fn::Join": [ "", [ - " import { util } from '@aws-appsync/utils'; + "import { util } from '@aws-appsync/utils'; - export function request(ctx) { - const { args, identity, request, prev } = ctx; - const toolDefinitions = {"tools":[{"name":"getTemperature","description":"does a thing","inputSchema":{"json":{"type":"object","properties":{"city":{"type":"string","description":"A UTF-8 character sequence."}},"required":["city"]}},"graphqlRequestInputDescriptor":{"selectionSet":"value unit","propertyTypes":{"city":"String!"},"queryName":"getTemperature"}},{"name":"plus","description":"does a different thing","inputSchema":{"json":{"type":"object","properties":{"a":{"type":"number","description":"A signed 32-bit integer value."},"b":{"type":"number","description":"A signed 32-bit integer value."}},"required":[]}},"graphqlRequestInputDescriptor":{"selectionSet":"","propertyTypes":{"a":"Int","b":"Int"},"queryName":"plus"}}]}; - const selectionSet = 'id conversationId content { image { format source { bytes }} text toolUse { toolUseId name input } toolResult { status toolUseId content { json text image { format source { bytes }} document { format name source { bytes }} }}} role owner createdAt updatedAt'; - const graphqlApiEndpoint = '", +export function request(ctx) { + const { args, request, prev } = ctx; + const toolDefinitions = {"tools":[{"name":"getTemperature","description":"does a thing","inputSchema":{"json":{"type":"object","properties":{"city":{"type":"string","description":"A UTF-8 character sequence."}},"required":["city"]}},"graphqlRequestInputDescriptor":{"selectionSet":"value unit","propertyTypes":{"city":"String!"},"queryName":"getTemperature"}},{"name":"plus","description":"does a different thing","inputSchema":{"json":{"type":"object","properties":{"a":{"type":"number","description":"A signed 32-bit integer value."},"b":{"type":"number","description":"A signed 32-bit integer value."}},"required":[]}},"graphqlRequestInputDescriptor":{"selectionSet":"","propertyTypes":{"a":"Int","b":"Int"},"queryName":"plus"}}]}; + const selectionSet = 'id conversationId content { image { format source { bytes }} text toolUse { toolUseId name input } toolResult { status toolUseId content { json text image { format source { bytes }} document { format name source { bytes }} }}} role owner createdAt updatedAt'; + const graphqlApiEndpoint = '", { "Fn::GetAtt": [ "GraphQLAPI", @@ -429,62 +462,73 @@ exports[`ConversationTransformer valid schemas should transform conversation rou }, "'; - const messages = prev.result.items; - const responseMutation = { - name: 'createAssistantResponsePirateChat', - inputTypeName: 'CreateConversationMessagePirateChatAssistantInput', - selectionSet, - }; - const currentMessageId = ctx.stash.defaultValues.id; - const modelConfiguration = { + const messages = prev.result.items; + const responseMutation = { + name: 'createAssistantResponsePirateChat', + inputTypeName: 'CreateConversationMessagePirateChatAssistantInput', + selectionSet, + }; + const currentMessageId = ctx.stash.defaultValues.id; + const modelConfiguration = { modelId: 'anthropic.claude-3-haiku-20240307-v1:0', systemPrompt: "You are a helpful chatbot. Answer questions to the best of your ability.", }; - const clientTools = args.toolConfiguration?.tools?.map((tool) => { return { ...tool.toolSpec }}); - const dataTools = toolDefinitions.tools; + const clientTools = args.toolConfiguration?.tools?.map((tool) => { + return { ...tool.toolSpec }; + }); + const dataTools = toolDefinitions.tools; const toolsConfiguration = { dataTools, clientTools, }; - const authHeader = request.headers['authorization']; - const payload = { - conversationId: args.conversationId, - currentMessageId, - responseMutation, - graphqlApiEndpoint, - modelConfiguration, - request: { headers: { authorization: authHeader }}, - messages, - toolsConfiguration, - }; - - return { - operation: 'Invoke', - payload, - invocationType: 'Event' - }; - } + const messageHistoryQuery = { + getQueryName: 'getConversationMessagePirateChat', + getQueryInputTypeName: 'ID', + listQueryName: 'listConversationMessagePirateChats', + listQueryInputTypeName: 'ModelConversationMessagePirateChatFilterInput', + listQueryLimit: undefined, + }; + + const authHeader = request.headers['authorization']; + const payload = { + conversationId: args.conversationId, + currentMessageId, + responseMutation, + graphqlApiEndpoint, + modelConfiguration, + request: { headers: { authorization: authHeader } }, + messageHistoryQuery, + toolsConfiguration, + }; + + return { + operation: 'Invoke', + payload, + invocationType: 'Event', + }; +} export function response(ctx) { - let success = true; if (ctx.error) { util.appendError(ctx.error.message, ctx.error.type); - success = false; } const response = { - __typename: 'ConversationMessagePirateChat', - id: ctx.stash.defaultValues.id, - conversationId: ctx.args.conversationId, - role: 'user', - content: ctx.args.content, - createdAt: ctx.stash.defaultValues.createdAt, - updatedAt: ctx.stash.defaultValues.updatedAt, + __typename: 'ConversationMessagePirateChat', + id: ctx.stash.defaultValues.id, + conversationId: ctx.args.conversationId, + role: 'user', + content: ctx.args.content, + aiContext: ctx.args.aiContext, + toolConfiguration: ctx.args.toolConfiguration, + createdAt: ctx.stash.defaultValues.createdAt, + updatedAt: ctx.stash.defaultValues.updatedAt, }; return response; -}", +} +", ], ], } diff --git a/packages/amplify-graphql-conversation-transformer/src/__tests__/amplify-graphql-conversation-transformer.test.ts b/packages/amplify-graphql-conversation-transformer/src/__tests__/amplify-graphql-conversation-transformer.test.ts index a0a9153c5f..d3ac79cdb0 100644 --- a/packages/amplify-graphql-conversation-transformer/src/__tests__/amplify-graphql-conversation-transformer.test.ts +++ b/packages/amplify-graphql-conversation-transformer/src/__tests__/amplify-graphql-conversation-transformer.test.ts @@ -43,6 +43,11 @@ describe('ConversationTransformer', () => { const schema = parse(out.schema); validateModelSchema(schema); + + expect( + out.stacks.ConversationMessagePirateChat.Resources![`ListConversationMessage${toUpper(routeName)}Resolver`].Properties + .PipelineConfig.Functions, + ).toHaveLength(5); }); }); @@ -79,17 +84,28 @@ describe('ConversationTransformer', () => { }); const assertResolverSnapshot = (routeName: string, resources: DeploymentResources) => { - const resolverCode = resources.rootStack.Resources?.[`Mutation${routeName}Resolver`]?.['Properties']['Code']; + const resolverCode = getResolverResource(routeName, resources.rootStack.Resources)['Properties']['Code']; + expect(resolverCode).toBeDefined(); + expect(resolverCode).toMatchSnapshot(); + + const resolverFnCode = getResolverFnResource(routeName, resources); + expect(resolverFnCode).toBeDefined(); + expect(resolverFnCode).toMatchSnapshot(); +}; + +const getResolverResource = (mutationName: string, resources?: Record): Record => { + const resolverName = `Mutation${mutationName}Resolver`; + return resources?.[resolverName]; +}; + +const getResolverFnResource = (mutationName: string, resources: DeploymentResources): string => { const resolverFnCode = resources.rootStack.Resources && - Object.entries(resources.rootStack.Resources).find(([key, _]) => key.startsWith(`Mutation${toUpper(routeName)}DataResolverFn`))?.[1][ + Object.entries(resources.rootStack.Resources).find(([key, _]) => key.startsWith(`Mutation${toUpper(mutationName)}DataResolverFn`))?.[1][ 'Properties' ]['Code']; - expect(resolverCode).toBeDefined(); - expect(resolverCode).toMatchSnapshot(); - expect(resolverFnCode).toBeDefined(); - expect(resolverFnCode).toMatchSnapshot(); + return resolverFnCode; }; const defaultAuthConfig: AppSyncAuthConfiguration = { diff --git a/packages/amplify-graphql-conversation-transformer/src/resolvers/assistant-messages-subscription-resolver-fn.template.js b/packages/amplify-graphql-conversation-transformer/src/resolvers/assistant-messages-subscription-resolver-fn.template.js new file mode 100644 index 0000000000..6ce17c873f --- /dev/null +++ b/packages/amplify-graphql-conversation-transformer/src/resolvers/assistant-messages-subscription-resolver-fn.template.js @@ -0,0 +1,43 @@ +import { util, extensions } from '@aws-appsync/utils'; + +export function request(ctx) { + ctx.stash.hasAuth = true; + const isAuthorized = false; + + if (util.authType() === 'User Pool Authorization') { + if (!isAuthorized) { + const authFilter = []; + let ownerClaim0 = ctx.identity['claims']['sub']; + ctx.args.owner = ownerClaim0; + const currentClaim1 = ctx.identity['claims']['username'] ?? ctx.identity['claims']['cognito:username']; + if (ownerClaim0 && currentClaim1) { + ownerClaim0 = ownerClaim0 + '::' + currentClaim1; + authFilter.push({ owner: { eq: ownerClaim0 } }); + } + const role0_0 = ctx.identity['claims']['sub']; + if (role0_0) { + authFilter.push({ owner: { eq: role0_0 } }); + } + // we can just reuse currentClaim1 here, but doing this (for now) to mirror the existing + // vtl auth resolver. + const role0_1 = ctx.identity['claims']['username'] ?? ctx.identity['claims']['cognito:username']; + if (role0_1) { + authFilter.push({ owner: { eq: role0_1 } }); + } + if (authFilter.length !== 0) { + ctx.stash.authFilter = { or: authFilter }; + } + } + } + if (!isAuthorized && ctx.stash.authFilter.length === 0) { + util.unauthorized(); + } + ctx.args.filter = { ...ctx.args.filter, and: [{ conversationId: { eq: ctx.args.conversationId } }] }; + return { version: '2018-05-29', payload: {} }; +} + +export function response(ctx) { + const subscriptionFilter = util.transform.toSubscriptionFilter(ctx.args.filter); + extensions.setSubscriptionFilter(subscriptionFilter); + return null; +} diff --git a/packages/amplify-graphql-conversation-transformer/src/resolvers/assistant-messages-subscription-resolver.ts b/packages/amplify-graphql-conversation-transformer/src/resolvers/assistant-messages-subscription-resolver.ts index 5493b5d3d8..18dd96d009 100644 --- a/packages/amplify-graphql-conversation-transformer/src/resolvers/assistant-messages-subscription-resolver.ts +++ b/packages/amplify-graphql-conversation-transformer/src/resolvers/assistant-messages-subscription-resolver.ts @@ -1,6 +1,8 @@ import { MappingTemplate } from '@aws-amplify/graphql-transformer-core'; import { MappingTemplateProvider } from '@aws-amplify/graphql-transformer-interfaces'; -import { dedent } from 'ts-dedent'; +import fs from 'fs'; +import path from 'path'; +import { ConversationDirectiveConfiguration } from '../grapqhl-conversation-transformer'; /** * Creates and returns the mapping template for the conversation message subscription resolver. @@ -8,74 +10,8 @@ import { dedent } from 'ts-dedent'; * * @returns {MappingTemplateProvider} An object containing request and response MappingTemplateProviders. */ -export const conversationMessageSubscriptionMappingTamplate = (): MappingTemplateProvider => { - const req = createAssistantMessagesSubscriptionRequestFunction(); - const res = createAssistantMessagesSubscriptionResponseFunction(); - return MappingTemplate.inlineTemplateFromString(dedent(req + '\n' + res)); -}; - -/** - * Creates the request function for the conversation message subscription resolver. - * This function handles the authorization and filtering of the conversation messages for owner auth. - * - * @returns {MappingTemplateProvider} A MappingTemplateProvider for the request function. - */ -const createAssistantMessagesSubscriptionRequestFunction = (): string => { - const requestFunctionString = ` - export function request(ctx) { - ctx.stash.hasAuth = true; - const isAuthorized = false; - - if (util.authType() === 'User Pool Authorization') { - if (!isAuthorized) { - const authFilter = []; - let ownerClaim0 = ctx.identity['claims']['sub']; - ctx.args.owner = ownerClaim0; - const currentClaim1 = ctx.identity['claims']['username'] ?? ctx.identity['claims']['cognito:username']; - if (ownerClaim0 && currentClaim1) { - ownerClaim0 = ownerClaim0 + '::' + currentClaim1; - authFilter.push({ owner: { eq: ownerClaim0 } }) - } - const role0_0 = ctx.identity['claims']['sub']; - if (role0_0) { - authFilter.push({ owner: { eq: role0_0 } }); - } - // we can just reuse currentClaim1 here, but doing this (for now) to mirror the existing - // vtl auth resolver. - const role0_1 = ctx.identity['claims']['username'] ?? ctx.identity['claims']['cognito:username']; - if (role0_1) { - authFilter.push({ owner: { eq: role0_1 }}); - } - if (authFilter.length !== 0) { - ctx.stash.authFilter = { or: authFilter }; - } - } - } - if (!isAuthorized && ctx.stash.authFilter.length === 0) { - util.unauthorized(); - } - ctx.args.filter = { ...ctx.args.filter, and: [{ conversationId: { eq: ctx.args.conversationId }}]}; - return { version: '2018-05-29', payload: {} }; - }`; - - return requestFunctionString; -}; - -/** - * Creates the response function for the conversation message subscription resolver. - * This function handles the subscription filter and sets the subscription filter for the conversation messages. - * - * @returns {MappingTemplateProvider} A MappingTemplateProvider for the response function. - */ -const createAssistantMessagesSubscriptionResponseFunction = (): string => { - const responseFunctionString = ` - import { util, extensions } from '@aws-appsync/utils'; - - export function response(ctx) { - const subscriptionFilter = util.transform.toSubscriptionFilter(ctx.args.filter); - extensions.setSubscriptionFilter(subscriptionFilter); - return null; - }`; - - return responseFunctionString; +export const conversationMessageSubscriptionMappingTamplate = (config: ConversationDirectiveConfiguration): MappingTemplateProvider => { + const resolver = fs.readFileSync(path.join(__dirname, 'assistant-messages-subscription-resolver-fn.template.js'), 'utf8'); + const templateName = `Subscription.${config.field.name.value}.assistant-message.js`; + return MappingTemplate.s3MappingFunctionCodeFromString(resolver, templateName); }; diff --git a/packages/amplify-graphql-conversation-transformer/src/resolvers/assistant-mutation-resolver-fn.template.js b/packages/amplify-graphql-conversation-transformer/src/resolvers/assistant-mutation-resolver-fn.template.js new file mode 100644 index 0000000000..d35ff74551 --- /dev/null +++ b/packages/amplify-graphql-conversation-transformer/src/resolvers/assistant-mutation-resolver-fn.template.js @@ -0,0 +1,58 @@ +import { util } from '@aws-appsync/utils'; + +/** + * Sends a request to the attached data source + * @param {import('@aws-appsync/utils').Context} ctx the context + * @returns {*} the request + */ +export function request(ctx) { + const owner = ctx.identity['claims']['sub']; + ctx.stash.owner = owner; + const { conversationId, content, associatedUserMessageId } = ctx.args.input; + const updatedAt = util.time.nowISO8601(); + + const expression = 'SET #assistantContent = :assistantContent, #updatedAt = :updatedAt'; + const expressionNames = { '#assistantContent': 'assistantContent', '#updatedAt': 'updatedAt' }; + const expressionValues = { ':assistantContent': content, ':updatedAt': updatedAt }; + const condition = JSON.parse( + util.transform.toDynamoDBConditionExpression({ + owner: { eq: owner }, + conversationId: { eq: conversationId }, + }), + ); + return { + operation: 'UpdateItem', + key: util.dynamodb.toMapValues({ id: associatedUserMessageId }), + condition, + update: { + expression, + expressionNames, + expressionValues: util.dynamodb.toMapValues(expressionValues), + }, + }; +} + +/** + * Returns the resolver result + * @param {import('@aws-appsync/utils').Context} ctx the context + * @returns {*} the result + */ +export function response(ctx) { + // Update with response logic + if (ctx.error) { + util.error(ctx.error.message, ctx.error.type); + } + + const { conversationId, content, associatedUserMessageId } = ctx.args.input; + const { createdAt, updatedAt } = ctx.result; + + return { + id: associatedUserMessageId, + content, + conversationId, + role: 'assistant', + owner: ctx.stash.owner, + createdAt, + updatedAt, + }; +} diff --git a/packages/amplify-graphql-conversation-transformer/src/resolvers/assistant-mutation-resolver.ts b/packages/amplify-graphql-conversation-transformer/src/resolvers/assistant-mutation-resolver.ts index df5c21ae0d..ebfcf45d5e 100644 --- a/packages/amplify-graphql-conversation-transformer/src/resolvers/assistant-mutation-resolver.ts +++ b/packages/amplify-graphql-conversation-transformer/src/resolvers/assistant-mutation-resolver.ts @@ -1,6 +1,8 @@ import { MappingTemplate } from '@aws-amplify/graphql-transformer-core'; import { MappingTemplateProvider } from '@aws-amplify/graphql-transformer-interfaces'; -import { dedent } from 'ts-dedent'; +import fs from 'fs'; +import path from 'path'; +import { ConversationDirectiveConfiguration } from '../grapqhl-conversation-transformer'; /** * Creates and returns the mapping template for the assistant mutation resolver. @@ -8,89 +10,8 @@ import { dedent } from 'ts-dedent'; * * @returns {MappingTemplateProvider} An object containing request and response MappingTemplateProviders. */ -export const assistantMutationResolver = (): MappingTemplateProvider => { - const req = createAssistantMutationRequestFunction(); - const res = createAssistantMutationResponseFunction(); - return MappingTemplate.inlineTemplateFromString(dedent(req + '\n' + res)); -}; - -/** - * Creates the request function for the assistant mutation resolver. - * This function handles the update of the assistant's response in the conversation. - * - * @returns {MappingTemplateProvider} A MappingTemplateProvider for the request function. - */ -const createAssistantMutationRequestFunction = (): string => { - const requestFunctionString = ` - import { util } from '@aws-appsync/utils'; - - /** - * Sends a request to the attached data source - * @param {import('@aws-appsync/utils').Context} ctx the context - * @returns {*} the request - */ - export function request(ctx) { - const owner = ctx.identity['claims']['sub']; - ctx.stash.owner = owner; - const { conversationId, content, associatedUserMessageId } = ctx.args.input; - const updatedAt = util.time.nowISO8601(); - - const expression = 'SET #assistantContent = :assistantContent, #updatedAt = :updatedAt'; - const expressionNames = { '#assistantContent': 'assistantContent', '#updatedAt': 'updatedAt' }; - const expressionValues = { ':assistantContent': content, ':updatedAt': updatedAt }; - const condition = JSON.parse( - util.transform.toDynamoDBConditionExpression({ - owner: { eq: owner }, - conversationId: { eq: conversationId } - }) - ); - return { - operation: 'UpdateItem', - key: util.dynamodb.toMapValues({ id: associatedUserMessageId }), - condition, - update: { - expression, - expressionNames, - expressionValues: util.dynamodb.toMapValues(expressionValues), - } - }; - }`; - - return requestFunctionString; -}; - -/** - * Creates the response function for the assistant mutation resolver. - * This function handles the processing of the response after the mutation. - * - * @returns {MappingTemplateProvider} A MappingTemplateProvider for the response function. - */ -const createAssistantMutationResponseFunction = (): string => { - const responseFunctionString = ` - /** - * Returns the resolver result - * @param {import('@aws-appsync/utils').Context} ctx the context - * @returns {*} the result - */ - export function response(ctx) { - // Update with response logic - if (ctx.error) { - util.error(ctx.error.message, ctx.error.type); - } - - const { conversationId, content, associatedUserMessageId } = ctx.args.input; - const { createdAt, updatedAt } = ctx.result; - - return { - id: associatedUserMessageId, - content, - conversationId, - role: 'assistant', - owner: ctx.stash.owner, - createdAt, - updatedAt, - }; - }`; - - return responseFunctionString; +export const assistantMutationResolver = (config: ConversationDirectiveConfiguration): MappingTemplateProvider => { + const resolver = fs.readFileSync(path.join(__dirname, 'assistant-mutation-resolver-fn.template.js'), 'utf8'); + const templateName = `Mutation.${config.field.name.value}.assistant-response.js`; + return MappingTemplate.s3MappingFunctionCodeFromString(resolver, templateName); }; diff --git a/packages/amplify-graphql-conversation-transformer/src/resolvers/auth-resolver-fn.template.js b/packages/amplify-graphql-conversation-transformer/src/resolvers/auth-resolver-fn.template.js new file mode 100644 index 0000000000..d38983b42b --- /dev/null +++ b/packages/amplify-graphql-conversation-transformer/src/resolvers/auth-resolver-fn.template.js @@ -0,0 +1,38 @@ +export function request(ctx) { + ctx.stash.hasAuth = true; + const isAuthorized = false; + + if (util.authType() === 'User Pool Authorization') { + if (!isAuthorized) { + const authFilter = []; + let ownerClaim0 = ctx.identity['claims']['sub']; + ctx.args.owner = ownerClaim0; + const currentClaim1 = ctx.identity['claims']['username'] ?? ctx.identity['claims']['cognito:username']; + if (ownerClaim0 && currentClaim1) { + ownerClaim0 = ownerClaim0 + '::' + currentClaim1; + authFilter.push({ owner: { eq: ownerClaim0 } }); + } + const role0_0 = ctx.identity['claims']['sub']; + if (role0_0) { + authFilter.push({ owner: { eq: role0_0 } }); + } + // we can just reuse currentClaim1 here, but doing this (for now) to mirror the existing + // vtl auth resolver. + const role0_1 = ctx.identity['claims']['username'] ?? ctx.identity['claims']['cognito:username']; + if (role0_1) { + authFilter.push({ owner: { eq: role0_1 } }); + } + if (authFilter.length !== 0) { + ctx.stash.authFilter = { or: authFilter }; + } + } + } + if (!isAuthorized && ctx.stash.authFilter.length === 0) { + util.unauthorized(); + } + return { version: '2018-05-29', payload: {} }; +} + +export function response(ctx) { + return {}; +} diff --git a/packages/amplify-graphql-conversation-transformer/src/resolvers/auth-resolver.ts b/packages/amplify-graphql-conversation-transformer/src/resolvers/auth-resolver.ts index ca7c166b2d..0766e0cfc8 100644 --- a/packages/amplify-graphql-conversation-transformer/src/resolvers/auth-resolver.ts +++ b/packages/amplify-graphql-conversation-transformer/src/resolvers/auth-resolver.ts @@ -1,6 +1,8 @@ import { MappingTemplate } from '@aws-amplify/graphql-transformer-core'; import { MappingTemplateProvider } from '@aws-amplify/graphql-transformer-interfaces'; -import { dedent } from 'ts-dedent'; +import fs from 'fs'; +import path from 'path'; +import { ConversationDirectiveConfiguration } from '../grapqhl-conversation-transformer'; /** * Creates and returns the mapping template for the auth resolver. @@ -8,70 +10,8 @@ import { dedent } from 'ts-dedent'; * * @returns {MappingTemplateProvider} An object containing request and response MappingTemplateProviders. */ -export const authMappingTemplate = (): MappingTemplateProvider => { - const req = createAuthRequestFunction(); - const res = createAuthResponseFunction(); - return MappingTemplate.inlineTemplateFromString(dedent(req + '\n' + res)); -}; - -/** - * Creates the request function for the auth resolver. - * This function handles authorization logic for owner based auth. - * - * @returns {MappingTemplateProvider} A MappingTemplateProvider for the request function. - */ -const createAuthRequestFunction = (): string => { - const requestFunctionString = ` - export function request(ctx) { - ctx.stash.hasAuth = true; - const isAuthorized = false; - - if (util.authType() === 'User Pool Authorization') { - if (!isAuthorized) { - const authFilter = []; - let ownerClaim0 = ctx.identity['claims']['sub']; - ctx.args.owner = ownerClaim0; - const currentClaim1 = ctx.identity['claims']['username'] ?? ctx.identity['claims']['cognito:username']; - if (ownerClaim0 && currentClaim1) { - ownerClaim0 = ownerClaim0 + '::' + currentClaim1; - authFilter.push({ owner: { eq: ownerClaim0 } }) - } - const role0_0 = ctx.identity['claims']['sub']; - if (role0_0) { - authFilter.push({ owner: { eq: role0_0 } }); - } - // we can just reuse currentClaim1 here, but doing this (for now) to mirror the existing - // vtl auth resolver. - const role0_1 = ctx.identity['claims']['username'] ?? ctx.identity['claims']['cognito:username']; - if (role0_1) { - authFilter.push({ owner: { eq: role0_1 }}); - } - if (authFilter.length !== 0) { - ctx.stash.authFilter = { or: authFilter }; - } - } - } - if (!isAuthorized && ctx.stash.authFilter.length === 0) { - util.unauthorized(); - } - return { version: '2018-05-29', payload: {} }; - }`; - - return requestFunctionString; -}; - -/** - * Creates the response function for the auth resolver. - * This function currently returns an empty object. - * - * @returns {MappingTemplateProvider} A MappingTemplateProvider for the response function. - */ - -const createAuthResponseFunction = (): string => { - const responseFunctionString = ` - export function response(ctx) { - return {}; - }`; - - return responseFunctionString; +export const authMappingTemplate = (config: ConversationDirectiveConfiguration): MappingTemplateProvider => { + const resolver = fs.readFileSync(path.join(__dirname, 'auth-resolver-fn.template.js'), 'utf8'); + const templateName = `Mutation.${config.field.name.value}.auth.js`; + return MappingTemplate.s3MappingFunctionCodeFromString(resolver, templateName); }; diff --git a/packages/amplify-graphql-conversation-transformer/src/resolvers/init-resolver-fn.template.js b/packages/amplify-graphql-conversation-transformer/src/resolvers/init-resolver-fn.template.js new file mode 100644 index 0000000000..e5064eac1c --- /dev/null +++ b/packages/amplify-graphql-conversation-transformer/src/resolvers/init-resolver-fn.template.js @@ -0,0 +1,15 @@ +export function request(ctx) { + ctx.stash.defaultValues = ctx.stash.defaultValues ?? {}; + ctx.stash.defaultValues.id = util.autoId(); + const createdAt = util.time.nowISO8601(); + ctx.stash.defaultValues.createdAt = createdAt; + ctx.stash.defaultValues.updatedAt = createdAt; + return { + version: '2018-05-09', + payload: {}, + }; +} + +export function response(ctx) { + return {}; +} diff --git a/packages/amplify-graphql-conversation-transformer/src/resolvers/init-resolver.ts b/packages/amplify-graphql-conversation-transformer/src/resolvers/init-resolver.ts index 6665218f0c..5683b01e24 100644 --- a/packages/amplify-graphql-conversation-transformer/src/resolvers/init-resolver.ts +++ b/packages/amplify-graphql-conversation-transformer/src/resolvers/init-resolver.ts @@ -1,6 +1,8 @@ import { MappingTemplate } from '@aws-amplify/graphql-transformer-core'; import { MappingTemplateProvider } from '@aws-amplify/graphql-transformer-interfaces'; -import { dedent } from 'ts-dedent'; +import fs from 'fs'; +import path from 'path'; +import { ConversationDirectiveConfiguration } from '../grapqhl-conversation-transformer'; /** * Creates and returns the mapping template for the init resolver. @@ -8,47 +10,8 @@ import { dedent } from 'ts-dedent'; * * @returns {MappingTemplateProvider} An object containing request and response MappingTemplateProviders. */ -export const initMappingTemplate = (): MappingTemplateProvider => { - const req = createInitRequestFunction(); - const res = createInitResponseFunction(); - return MappingTemplate.inlineTemplateFromString(dedent(req + '\n' + res)); -}; - -/** - * Creates the request function for the init resolver. - * This function sets up default values for id, createdAt, and updatedAt. - * - * @returns {MappingTemplateProvider} A MappingTemplateProvider for the request function. - */ - -const createInitRequestFunction = (): string => { - const requestFunctionString = ` - export function request(ctx) { - ctx.stash.defaultValues = ctx.stash.defaultValues ?? {}; - ctx.stash.defaultValues.id = util.autoId(); - const createdAt = util.time.nowISO8601(); - ctx.stash.defaultValues.createdAt = createdAt; - ctx.stash.defaultValues.updatedAt = createdAt; - return { - version: '2018-05-09', - payload: {} - }; - }`; - - return requestFunctionString; -}; - -/** - * Creates the response function for the init resolver. - * This function currently returns an empty object. - * - * @returns {MappingTemplateProvider} A MappingTemplateProvider for the response function. - */ -const createInitResponseFunction = (): string => { - const responseFunctionString = ` - export function response(ctx) { - return {}; - }`; - - return responseFunctionString; +export const initMappingTemplate = (config: ConversationDirectiveConfiguration): MappingTemplateProvider => { + const resolver = fs.readFileSync(path.join(__dirname, 'init-resolver-fn.template.js'), 'utf8'); + const templateName = `Mutation.${config.field.name.value}.init.js`; + return MappingTemplate.s3MappingFunctionCodeFromString(resolver, templateName); }; diff --git a/packages/amplify-graphql-conversation-transformer/src/resolvers/invoke-lambda-resolver-fn.template.js b/packages/amplify-graphql-conversation-transformer/src/resolvers/invoke-lambda-resolver-fn.template.js new file mode 100644 index 0000000000..8d008d7070 --- /dev/null +++ b/packages/amplify-graphql-conversation-transformer/src/resolvers/invoke-lambda-resolver-fn.template.js @@ -0,0 +1,66 @@ +import { util } from '@aws-appsync/utils'; + +export function request(ctx) { + const { args, request, prev } = ctx; + [[TOOL_DEFINITIONS_LINE]] + const selectionSet = '[[SELECTION_SET]]'; + const graphqlApiEndpoint = '[[GRAPHQL_API_ENDPOINT]]'; + + const messages = prev.result.items; + const responseMutation = { + name: '[[RESPONSE_MUTATION_NAME]]', + inputTypeName: '[[RESPONSE_MUTATION_INPUT_TYPE_NAME]]', + selectionSet, + }; + const currentMessageId = ctx.stash.defaultValues.id; + [[MODEL_CONFIGURATION_LINE]] + + const clientTools = args.toolConfiguration?.tools?.map((tool) => { + return { ...tool.toolSpec }; + }); + [[TOOLS_CONFIGURATION_LINE]] + + const messageHistoryQuery = { + getQueryName: '[[GET_QUERY_NAME]]', + getQueryInputTypeName: '[[GET_QUERY_INPUT_TYPE_NAME]]', + listQueryName: '[[LIST_QUERY_NAME]]', + listQueryInputTypeName: '[[LIST_QUERY_INPUT_TYPE_NAME]]', + listQueryLimit: [[LIST_QUERY_LIMIT]], + }; + + const authHeader = request.headers['authorization']; + const payload = { + conversationId: args.conversationId, + currentMessageId, + responseMutation, + graphqlApiEndpoint, + modelConfiguration, + request: { headers: { authorization: authHeader } }, + messageHistoryQuery, + toolsConfiguration, + }; + + return { + operation: 'Invoke', + payload, + invocationType: 'Event', + }; +} + +export function response(ctx) { + if (ctx.error) { + util.appendError(ctx.error.message, ctx.error.type); + } + const response = { + __typename: '[[MESSAGE_MODEL_NAME]]', + id: ctx.stash.defaultValues.id, + conversationId: ctx.args.conversationId, + role: 'user', + content: ctx.args.content, + aiContext: ctx.args.aiContext, + toolConfiguration: ctx.args.toolConfiguration, + createdAt: ctx.stash.defaultValues.createdAt, + updatedAt: ctx.stash.defaultValues.updatedAt, + }; + return response; +} diff --git a/packages/amplify-graphql-conversation-transformer/src/resolvers/invoke-lambda-resolver.ts b/packages/amplify-graphql-conversation-transformer/src/resolvers/invoke-lambda-resolver.ts index ca297576bf..d9495e07dd 100644 --- a/packages/amplify-graphql-conversation-transformer/src/resolvers/invoke-lambda-resolver.ts +++ b/packages/amplify-graphql-conversation-transformer/src/resolvers/invoke-lambda-resolver.ts @@ -1,7 +1,11 @@ import { TransformerContextProvider, MappingTemplateProvider } from '@aws-amplify/graphql-transformer-interfaces'; import { MappingTemplate } from '@aws-amplify/graphql-transformer-core'; import { ConversationDirectiveConfiguration } from '../grapqhl-conversation-transformer'; -import { dedent } from 'ts-dedent'; +import fs from 'fs'; +import path from 'path'; +import dedent from 'ts-dedent'; +import { toUpper } from 'graphql-transformer-common'; +import pluralize from 'pluralize'; /** * Creates a mapping template for invoking a Lambda function in the context of a GraphQL conversation. @@ -14,27 +18,55 @@ export const invokeLambdaMappingTemplate = ( config: ConversationDirectiveConfiguration, ctx: TransformerContextProvider, ): MappingTemplateProvider => { - const req = createInvokeLambdaRequestFunction(config, ctx); - const res = createInvokeLambdaResponseFunction(config); - return MappingTemplate.inlineTemplateFromString(dedent(req + '\n' + res)); + const { TOOL_DEFINITIONS_LINE, TOOLS_CONFIGURATION_LINE } = generateToolLines(config); + const SELECTION_SET = selectionSet; + const GRAPHQL_API_ENDPOINT = ctx.api.graphqlUrl; + const MODEL_CONFIGURATION_LINE = generateModelConfigurationLine(config); + const RESPONSE_MUTATION_NAME = config.responseMutationName; + const RESPONSE_MUTATION_INPUT_TYPE_NAME = config.responseMutationInputTypeName; + const MESSAGE_MODEL_NAME = config.messageModel.messageModel.name.value; + + // TODO: Create and add these values to `ConversationDirectiveConfiguration` in an earlier step and + // access them here. + const GET_QUERY_NAME = `getConversationMessage${toUpper(config.field.name.value)}`; + const GET_QUERY_INPUT_TYPE_NAME = 'ID'; + const LIST_QUERY_NAME = `listConversationMessage${toUpper(pluralize(config.field.name.value))}`; + const LIST_QUERY_INPUT_TYPE_NAME = `ModelConversationMessage${toUpper(config.field.name.value)}FilterInput`; + const LIST_QUERY_LIMIT = 'undefined'; + + const substitutions = { + TOOL_DEFINITIONS_LINE, + TOOLS_CONFIGURATION_LINE, + SELECTION_SET, + GRAPHQL_API_ENDPOINT, + MODEL_CONFIGURATION_LINE, + RESPONSE_MUTATION_NAME, + RESPONSE_MUTATION_INPUT_TYPE_NAME, + MESSAGE_MODEL_NAME, + GET_QUERY_NAME, + GET_QUERY_INPUT_TYPE_NAME, + LIST_QUERY_NAME, + LIST_QUERY_INPUT_TYPE_NAME, + LIST_QUERY_LIMIT, + }; + + let resolver = fs.readFileSync(path.join(__dirname, 'invoke-lambda-resolver-fn.template.js'), 'utf8'); + Object.entries(substitutions).forEach(([key, value]) => { + const replaced = resolver.replace(new RegExp(`\\[\\[${key}\\]\\]`, 'g'), value); + resolver = replaced; + }); + + // This unfortunately needs to be an inline template because an s3 mapping template doesn't allow the CDK + // to substitute token values, which is necessary for this resolver function due to its reference of + // `ctx.api.graphqlUrl`. + return MappingTemplate.inlineTemplateFromString(resolver); }; -/** - * Creates a request function for invoking a Lambda function in the context of a GraphQL conversation. - * This function prepares the necessary data and configuration for the Lambda invocation. - * - * @param {ConversationDirectiveConfiguration} config - The configuration for the conversation directive. - * @param {TransformerContextProvider} ctx - The transformer context provider. - * @returns {MappingTemplateProvider} A function that generates the request mapping template. - */ -const createInvokeLambdaRequestFunction = (config: ConversationDirectiveConfiguration, ctx: TransformerContextProvider): string => { - const { responseMutationInputTypeName, responseMutationName } = config; +const generateToolLines = (config: ConversationDirectiveConfiguration) => { const toolDefinitions = JSON.stringify(config.toolSpec); - const toolDefinitionsLine = toolDefinitions ? `const toolDefinitions = ${toolDefinitions};` : ''; - const modelConfigurationLine = generateModelConfigurationLine(config); - const graphqlEndpoint = ctx.api.graphqlUrl; + const TOOL_DEFINITIONS_LINE = toolDefinitions ? `const toolDefinitions = ${toolDefinitions};` : ''; - const toolsConfigurationLine = toolDefinitions + const TOOLS_CONFIGURATION_LINE = toolDefinitions ? dedent`const dataTools = toolDefinitions.tools; const toolsConfiguration = { dataTools, @@ -44,70 +76,7 @@ const createInvokeLambdaRequestFunction = (config: ConversationDirectiveConfigur clientTools };`; - const requestFunctionString = ` - import { util } from '@aws-appsync/utils'; - - export function request(ctx) { - const { args, identity, request, prev } = ctx; - ${toolDefinitionsLine} - const selectionSet = '${selectionSet}'; - const graphqlApiEndpoint = '${graphqlEndpoint}'; - - const messages = prev.result.items; - const responseMutation = { - name: '${responseMutationName}', - inputTypeName: '${responseMutationInputTypeName}', - selectionSet, - }; - const currentMessageId = ctx.stash.defaultValues.id; - ${modelConfigurationLine} - - const clientTools = args.toolConfiguration?.tools?.map((tool) => { return { ...tool.toolSpec }}); - ${toolsConfigurationLine} - - const authHeader = request.headers['authorization']; - const payload = { - conversationId: args.conversationId, - currentMessageId, - responseMutation, - graphqlApiEndpoint, - modelConfiguration, - request: { headers: { authorization: authHeader }}, - messages, - toolsConfiguration, - }; - - return { - operation: 'Invoke', - payload, - invocationType: 'Event' - }; - }`; - - return requestFunctionString; -}; - -const createInvokeLambdaResponseFunction = (config: ConversationDirectiveConfiguration): string => { - const responseFunctionString = ` -export function response(ctx) { - let success = true; - if (ctx.error) { - util.appendError(ctx.error.message, ctx.error.type); - success = false; - } - const response = { - __typename: '${config.messageModel.messageModel.name.value}', - id: ctx.stash.defaultValues.id, - conversationId: ctx.args.conversationId, - role: 'user', - content: ctx.args.content, - createdAt: ctx.stash.defaultValues.createdAt, - updatedAt: ctx.stash.defaultValues.updatedAt, - }; - return response; -}`; - - return responseFunctionString; + return { TOOL_DEFINITIONS_LINE, TOOLS_CONFIGURATION_LINE }; }; /** diff --git a/packages/amplify-graphql-conversation-transformer/src/resolvers/list-messages-init-resolver-fn.js b/packages/amplify-graphql-conversation-transformer/src/resolvers/list-messages-init-resolver-fn.js new file mode 100644 index 0000000000..bea3d7a53e --- /dev/null +++ b/packages/amplify-graphql-conversation-transformer/src/resolvers/list-messages-init-resolver-fn.js @@ -0,0 +1,8 @@ +export function request(ctx) { + ctx.stash.metadata.index = 'gsi-ConversationMessage.conversationId.createdAt'; + return {}; +} + +export function response(ctx) { + return {}; +} \ No newline at end of file diff --git a/packages/amplify-graphql-conversation-transformer/src/resolvers/list-messages-init-resolver.ts b/packages/amplify-graphql-conversation-transformer/src/resolvers/list-messages-init-resolver.ts new file mode 100644 index 0000000000..259f13b270 --- /dev/null +++ b/packages/amplify-graphql-conversation-transformer/src/resolvers/list-messages-init-resolver.ts @@ -0,0 +1,16 @@ +import { MappingTemplate } from '@aws-amplify/graphql-transformer-core'; +import { MappingTemplateProvider } from '@aws-amplify/graphql-transformer-interfaces'; +import fs from 'fs'; +import path from 'path'; +import { ConversationDirectiveConfiguration } from '../grapqhl-conversation-transformer'; + +/** + * Creates and returns the function code for the list messages resolver init slot. + * + * @returns {MappingTemplateProvider} + */ +export const listMessageInitMappingTemplate = (config: ConversationDirectiveConfiguration): MappingTemplateProvider => { + const resolver = fs.readFileSync(path.join(__dirname, 'list-messages-init-resolver-fn.js'), 'utf8'); + const templateName = `Query.${config.field.name.value}.list-message-init.js`; + return MappingTemplate.s3MappingFunctionCodeFromString(resolver, templateName); +}; diff --git a/packages/amplify-graphql-conversation-transformer/src/resolvers/list-messages-post-data-load-resolver-fn.js b/packages/amplify-graphql-conversation-transformer/src/resolvers/list-messages-post-data-load-resolver-fn.js new file mode 100644 index 0000000000..a5b03acb6f --- /dev/null +++ b/packages/amplify-graphql-conversation-transformer/src/resolvers/list-messages-post-data-load-resolver-fn.js @@ -0,0 +1,30 @@ +export function request(ctx) { + return {}; +} + +export function response(ctx) { + const items = ctx.prev.result.items.reduce((acc, item) => { + const userMessage = { + ...item, + role: "user", + updatedAt: item.createdAt + }; + delete userMessage.assistantContent; + acc.push(userMessage); + + if (item.assistantContent) { + const assistantMessage = { + ...item, + role: "assistant", + content: item.assistantContent, + createdAt: item.updatedAt, + }; + delete assistantMessage.assistantContent; + acc.push(assistantMessage); + } + + return acc; + }, []); + + return { ...ctx.prev.result, items }; +} \ No newline at end of file diff --git a/packages/amplify-graphql-conversation-transformer/src/resolvers/list-messages-post-data-load-resolver.ts b/packages/amplify-graphql-conversation-transformer/src/resolvers/list-messages-post-data-load-resolver.ts new file mode 100644 index 0000000000..b4e95496c4 --- /dev/null +++ b/packages/amplify-graphql-conversation-transformer/src/resolvers/list-messages-post-data-load-resolver.ts @@ -0,0 +1,16 @@ +import { MappingTemplate } from '@aws-amplify/graphql-transformer-core'; +import { MappingTemplateProvider } from '@aws-amplify/graphql-transformer-interfaces'; +import fs from 'fs'; +import path from 'path'; +import { ConversationDirectiveConfiguration } from '../grapqhl-conversation-transformer'; + +/** + * Creates and returns the function code for the list messages resolver postDataLoad slot. + * + * @returns {MappingTemplateProvider} + */ +export const listMessagePostDataLoadMappingTemplate = (config: ConversationDirectiveConfiguration): MappingTemplateProvider => { + const resolver = fs.readFileSync(path.join(__dirname, 'list-messages-post-data-load-resolver-fn.js'), 'utf8'); + const templateName = `Query.${config.field.name.value}.list-messages-post-data-load.js`; + return MappingTemplate.s3MappingFunctionCodeFromString(resolver, templateName); +}; diff --git a/packages/amplify-graphql-conversation-transformer/src/resolvers/message-history-resolver-fn.template.js b/packages/amplify-graphql-conversation-transformer/src/resolvers/message-history-resolver-fn.template.js new file mode 100644 index 0000000000..082b5cac0d --- /dev/null +++ b/packages/amplify-graphql-conversation-transformer/src/resolvers/message-history-resolver-fn.template.js @@ -0,0 +1,49 @@ +export function request(ctx) { + const { conversationId } = ctx.args; + const { authFilter } = ctx.stash; + + const limit = 100; + const query = { + expression: 'conversationId = :conversationId', + expressionValues: util.dynamodb.toMapValues({ + ':conversationId': conversationId, + }), + }; + + const filter = JSON.parse(util.transform.toDynamoDBFilterExpression(authFilter)); + const index = 'gsi-ConversationMessage.conversationId.createdAt'; + + return { + operation: 'Query', + query, + filter, + index, + scanIndexForward: false, + }; +} + +export function response(ctx) { + if (ctx.error) { + util.error(ctx.error.message, ctx.error.type); + } + const messagesWithAssistantResponse = ctx.result.items + .filter((message) => message.assistantContent !== undefined) + .reduce((acc, current) => { + const { content, assistantContent, aiContext } = current; + const userContent = aiContext + ? [...content, { text: JSON.stringify(aiContext) }] + : content; + + acc.push({ role: 'user', content: userContent }); + acc.push({ role: 'assistant', content: assistantContent }); + return acc; + }, []); + + const { content, aiContext } = ctx.prev.result; + const currentUserMessageContent = aiContext + ? [...content, { text: JSON.stringify(aiContext) }] + : content; + const currentMessage = { role: 'user', content: currentUserMessageContent }; + const items = [...messagesWithAssistantResponse, currentMessage]; + return { items }; +} diff --git a/packages/amplify-graphql-conversation-transformer/src/resolvers/message-history-resolver.ts b/packages/amplify-graphql-conversation-transformer/src/resolvers/message-history-resolver.ts index a6f604bbd2..ce56cfa558 100644 --- a/packages/amplify-graphql-conversation-transformer/src/resolvers/message-history-resolver.ts +++ b/packages/amplify-graphql-conversation-transformer/src/resolvers/message-history-resolver.ts @@ -1,77 +1,16 @@ import { MappingTemplate } from '@aws-amplify/graphql-transformer-core'; import { MappingTemplateProvider } from '@aws-amplify/graphql-transformer-interfaces'; -import { dedent } from 'ts-dedent'; +import fs from 'fs'; +import path from 'path'; +import { ConversationDirectiveConfiguration } from '../grapqhl-conversation-transformer'; /** * Creates a mapping template for reading message history in a conversation. * * @returns {MappingTemplateProvider} An object containing request and response mapping functions. */ -export const readHistoryMappingTemplate = (): MappingTemplateProvider => { - // TODO: filter to only retrieve messages that have an assistant response. - const req = createMessageHistoryRequestFunction(); - const res = createMessageHistoryResponseFunction(); - - return MappingTemplate.inlineTemplateFromString(dedent(req + '\n' + res)); -}; - -/** - * Creates a request mapping template for reading message history in a conversation. - * - * @returns {MappingTemplateProvider} A mapping template provider for the request function. - */ -const createMessageHistoryRequestFunction = (): string => { - const requestFunctionString = ` - export function request(ctx) { - const { conversationId } = ctx.args; - const { authFilter } = ctx.stash; - - const limit = 100; - const query = { - expression: 'conversationId = :conversationId', - expressionValues: util.dynamodb.toMapValues({ - ':conversationId': ctx.args.conversationId - }) - }; - - const filter = JSON.parse(util.transform.toDynamoDBFilterExpression(authFilter)); - const index = 'gsi-ConversationMessage.conversationId.createdAt'; - - return { - operation: 'Query', - query, - filter, - index, - scanIndexForward: false, - } - }`; - - return requestFunctionString; -}; - -/** - * Creates a response mapping template for reading message history in a conversation. - * - * @returns {MappingTemplateProvider} A mapping template provider for the response function. - */ -const createMessageHistoryResponseFunction = (): string => { - const responseFunctionString = ` - export function response(ctx) { - if (ctx.error) { - util.error(ctx.error.message, ctx.error.type); - } - const messagesWithAssistantResponse = ctx.result.items - .filter((message) => message.assistantContent !== undefined) - .reduce((acc, current) => { - acc.push({ role: 'user', content: current.content }); - acc.push({ role: 'assistant', content: current.assistantContent }); - return acc; - }, []) - - const currentMessage = { role: 'user', content: ctx.prev.result.content }; - const items = [...messagesWithAssistantResponse, currentMessage]; - return { items }; - }`; - - return responseFunctionString; +export const readHistoryMappingTemplate = (config: ConversationDirectiveConfiguration): MappingTemplateProvider => { + const resolver = fs.readFileSync(path.join(__dirname, 'message-history-resolver-fn.template.js'), 'utf8'); + const templateName = `Mutation.${config.field.name.value}.message-history.js`; + return MappingTemplate.s3MappingFunctionCodeFromString(resolver, templateName); }; diff --git a/packages/amplify-graphql-conversation-transformer/src/resolvers/verify-session-owner-resolver-fn.template.js b/packages/amplify-graphql-conversation-transformer/src/resolvers/verify-session-owner-resolver-fn.template.js new file mode 100644 index 0000000000..7a05ff10b7 --- /dev/null +++ b/packages/amplify-graphql-conversation-transformer/src/resolvers/verify-session-owner-resolver-fn.template.js @@ -0,0 +1,30 @@ +export function request(ctx) { + const { authFilter } = ctx.stash; + + const query = { + expression: 'id = :id', + expressionValues: util.dynamodb.toMapValues({ + ':id': ctx.args.conversationId, + }), + }; + + const filter = JSON.parse(util.transform.toDynamoDBFilterExpression(authFilter)); + + return { + operation: 'Query', + query, + filter, + }; +} + +export function response(ctx) { + if (ctx.error) { + util.error(ctx.error.message, ctx.error.type); + } + + if (ctx.result.items.length !== 0) { + return ctx.result.items[0]; + } + + util.error('Conversation not found', 'ResourceNotFound'); +} diff --git a/packages/amplify-graphql-conversation-transformer/src/resolvers/verify-session-owner-resolver.ts b/packages/amplify-graphql-conversation-transformer/src/resolvers/verify-session-owner-resolver.ts index 3b1f1ee3ff..0624e1c9d5 100644 --- a/packages/amplify-graphql-conversation-transformer/src/resolvers/verify-session-owner-resolver.ts +++ b/packages/amplify-graphql-conversation-transformer/src/resolvers/verify-session-owner-resolver.ts @@ -1,66 +1,16 @@ import { MappingTemplate } from '@aws-amplify/graphql-transformer-core'; import { MappingTemplateProvider } from '@aws-amplify/graphql-transformer-interfaces'; -import { dedent } from 'ts-dedent'; +import fs from 'fs'; +import path from 'path'; +import { ConversationDirectiveConfiguration } from '../grapqhl-conversation-transformer'; /** * Creates a mapping template for verifying the session owner in a conversation. * * @returns {MappingTemplateProvider} An object containing request and response MappingTemplateProviders. */ -export const verifySessionOwnerMappingTemplate = (): MappingTemplateProvider => { - const req = createVerifySessionOwnerRequestFunction(); - const res = createVerifySessionOwnerResponseFunction(); - return MappingTemplate.inlineTemplateFromString(dedent(req + '\n' + res)); -}; - -/** - * Creates the request function for verifying the session owner in a conversation. - * - * @returns {MappingTemplateProvider} A MappingTemplateProvider for the request function. - */ -const createVerifySessionOwnerRequestFunction = (): string => { - const requestFunctionString = ` - export function request(ctx) { - const { authFilter } = ctx.stash; - - const query = { - expression: 'id = :id', - expressionValues: util.dynamodb.toMapValues({ - ':id': ctx.args.conversationId - }) - }; - - const filter = JSON.parse(util.transform.toDynamoDBFilterExpression(authFilter)); - - return { - operation: 'Query', - query, - filter - }; - } - `; - - return requestFunctionString; -}; - -/** - * Creates the response function for verifying the session owner in a conversation. - * - * @returns {MappingTemplateProvider} A MappingTemplateProvider for the response function. - */ -const createVerifySessionOwnerResponseFunction = (): string => { - const responseFunctionString = ` - export function response(ctx) { - if (ctx.error) { - util.error(ctx.error.message, ctx.error.type); - } - - if (ctx.result.items.length !== 0) { - return ctx.result.items[0]; - } - - util.error('Conversation not found', 'ResourceNotFound'); - }`; - - return responseFunctionString; +export const verifySessionOwnerMappingTemplate = (config: ConversationDirectiveConfiguration): MappingTemplateProvider => { + const resolver = fs.readFileSync(path.join(__dirname, 'verify-session-owner-resolver-fn.template.js'), 'utf8'); + const templateName = `Mutation.${config.field.name.value}.verify-session-owner.js`; + return MappingTemplate.s3MappingFunctionCodeFromString(resolver, templateName); }; diff --git a/packages/amplify-graphql-conversation-transformer/src/resolvers/write-message-to-table-resolver-fn.template.js b/packages/amplify-graphql-conversation-transformer/src/resolvers/write-message-to-table-resolver-fn.template.js new file mode 100644 index 0000000000..a19aeb0d57 --- /dev/null +++ b/packages/amplify-graphql-conversation-transformer/src/resolvers/write-message-to-table-resolver-fn.template.js @@ -0,0 +1,24 @@ +import { util } from '@aws-appsync/utils'; +import * as ddb from '@aws-appsync/utils/dynamodb'; + +export function request(ctx) { + const args = ctx.stash.transformedArgs ?? ctx.args; + const defaultValues = ctx.stash.defaultValues ?? {}; + const message = { + __typename: '[[CONVERSATION_MESSAGE_TYPE_NAME]]', + role: 'user', + ...args, + ...defaultValues, + }; + const id = ctx.stash.defaultValues.id; + + return ddb.put({ key: { id }, item: message }); +} + +export function response(ctx) { + if (ctx.error) { + util.error(ctx.error.message, ctx.error.type); + } else { + return ctx.result; + } +} diff --git a/packages/amplify-graphql-conversation-transformer/src/resolvers/write-message-to-table-resolver.ts b/packages/amplify-graphql-conversation-transformer/src/resolvers/write-message-to-table-resolver.ts index a4d1aded1b..cf9848dc01 100644 --- a/packages/amplify-graphql-conversation-transformer/src/resolvers/write-message-to-table-resolver.ts +++ b/packages/amplify-graphql-conversation-transformer/src/resolvers/write-message-to-table-resolver.ts @@ -1,63 +1,23 @@ import { MappingTemplate } from '@aws-amplify/graphql-transformer-core'; import { MappingTemplateProvider } from '@aws-amplify/graphql-transformer-interfaces'; -import { dedent } from 'ts-dedent'; +import fs from 'fs'; +import path from 'path'; /** * Creates a mapping template for writing a message to a table in a conversation. * - * @param {string} fieldName - The name of the field to write to the table. * @returns {MappingTemplateProvider} An object containing request and response MappingTemplateProviders. */ export const writeMessageToTableMappingTemplate = (fieldName: string): MappingTemplateProvider => { - const req = createWriteMessageToTableRequestFunction(fieldName); - const res = createWriteMessageToTableResponseFunction(); - return MappingTemplate.inlineTemplateFromString(dedent(req + '\n' + res)); -}; - -/** - * Creates the request function for writing a message to a table in a conversation. - * - * @param {string} fieldName - The name of the field to write to the table. - * @returns {MappingTemplateProvider} A MappingTemplateProvider for the request function. - */ -const createWriteMessageToTableRequestFunction = (fieldName: string): string => { - const requestFunctionString = ` - import { util } from '@aws-appsync/utils' - import * as ddb from '@aws-appsync/utils/dynamodb' - - export function request(ctx) { - const args = ctx.stash.transformedArgs ?? ctx.args; - const defaultValues = ctx.stash.defaultValues ?? {}; - const message = { - __typename: 'ConversationMessage${fieldName}', - role: 'user', - ...args, - ...defaultValues, - }; - const id = ctx.stash.defaultValues.id; - - return ddb.put({ key: { id }, item: message }); - } - `; - - return requestFunctionString; -}; - -/** - * Creates the response function for writing a message to a table in a conversation. - * - * @returns {MappingTemplateProvider} A MappingTemplateProvider for the response function. - */ -const createWriteMessageToTableResponseFunction = (): string => { - const responseFunctionString = ` - export function response(ctx) { - if (ctx.error) { - util.error(ctx.error.message, ctx.error.type); - } else { - return ctx.result; - } - } - `; + const substitutions = { + CONVERSATION_MESSAGE_TYPE_NAME: `ConversationMessage${fieldName}`, + }; + let resolver = fs.readFileSync(path.join(__dirname, 'write-message-to-table-resolver-fn.template.js'), 'utf8'); + Object.entries(substitutions).forEach(([key, value]) => { + const replaced = resolver.replace(new RegExp(`\\[\\[${key}\\]\\]`, 'g'), value); + resolver = replaced; + }); - return responseFunctionString; + const templateName = `Mutation.${fieldName}.write-message-to-table.js`; + return MappingTemplate.s3MappingFunctionCodeFromString(resolver, templateName); }; diff --git a/packages/amplify-graphql-conversation-transformer/src/transformer-steps/conversation-resolver-generator.ts b/packages/amplify-graphql-conversation-transformer/src/transformer-steps/conversation-resolver-generator.ts index 3d14cda69c..07faa52151 100644 --- a/packages/amplify-graphql-conversation-transformer/src/transformer-steps/conversation-resolver-generator.ts +++ b/packages/amplify-graphql-conversation-transformer/src/transformer-steps/conversation-resolver-generator.ts @@ -16,6 +16,9 @@ import { invokeLambdaMappingTemplate } from '../resolvers/invoke-lambda-resolver import { assistantMutationResolver } from '../resolvers/assistant-mutation-resolver'; import { conversationMessageSubscriptionMappingTamplate } from '../resolvers/assistant-messages-subscription-resolver'; import { overrideIndexAtCfnLevel } from '@aws-amplify/graphql-index-transformer'; +import pluralize from 'pluralize'; +import { listMessagePostDataLoadMappingTemplate } from '../resolvers/list-messages-post-data-load-resolver'; +import { listMessageInitMappingTemplate } from '../resolvers/list-messages-init-resolver'; type KeyAttributeDefinition = { name: string; @@ -28,6 +31,7 @@ export class ConversationResolverGenerator { for (const directive of directives) { this.processToolsForDirective(directive, ctx); this.generateResolversForDirective(directive, ctx); + this.addPostProcessingSlotToListMessagesPipeline(ctx, directive); } } @@ -48,14 +52,22 @@ export class ConversationResolverGenerator { const { functionDataSourceId, referencedFunction } = this.setupFunctionDataSource(directive, functionStack, capitalizedFieldName); this.createAssistantResponseResolver(ctx, directive, capitalizedFieldName); - this.createAssistantResponseSubscriptionResolver(ctx, capitalizedFieldName); + this.createAssistantResponseSubscriptionResolver(ctx, directive, capitalizedFieldName); const functionDataSource = this.addLambdaDataSource(ctx, functionDataSourceId, referencedFunction, capitalizedFieldName); const invokeLambdaFunction = invokeLambdaMappingTemplate(directive, ctx); this.setupMessageTableIndex(ctx, directive); - this.createConversationPipelineResolver(ctx, parentName, fieldName, capitalizedFieldName, functionDataSource, invokeLambdaFunction); + this.createConversationPipelineResolver( + ctx, + parentName, + fieldName, + capitalizedFieldName, + functionDataSource, + invokeLambdaFunction, + directive, + ); } /** @@ -161,23 +173,22 @@ export class ConversationResolverGenerator { capitalizedFieldName: string, functionDataSource: any, invokeLambdaFunction: MappingTemplateProvider, + directive: ConversationDirectiveConfiguration, ): void { const resolverResourceId = ResolverResourceIDs.ResolverResourceID(parentName, fieldName); - const mappingTemplate = { - codeMappingTemplate: invokeLambdaFunction, - }; + const runtime = APPSYNC_JS_RUNTIME; const conversationPipelineResolver = new TransformerResolver( parentName, fieldName, resolverResourceId, - mappingTemplate, + { codeMappingTemplate: invokeLambdaFunction }, ['init', 'auth', 'verifySessionOwner', 'writeMessageToTable', 'retrieveMessageHistory'], ['handleLambdaResponse', 'finish'], functionDataSource, - APPSYNC_JS_RUNTIME, + runtime, ); - this.addPipelineResolverFunctions(ctx, conversationPipelineResolver, capitalizedFieldName); + this.addPipelineResolverFunctions(ctx, conversationPipelineResolver, capitalizedFieldName, directive); ctx.resolvers.addResolver(parentName, fieldName, conversationPipelineResolver); } @@ -189,17 +200,22 @@ export class ConversationResolverGenerator { * @param capitalizedFieldName - The capitalized field name * @param runtime - The runtime configuration */ - private addPipelineResolverFunctions(ctx: TransformerContextProvider, resolver: TransformerResolver, capitalizedFieldName: string): void { + private addPipelineResolverFunctions( + ctx: TransformerContextProvider, + resolver: TransformerResolver, + capitalizedFieldName: string, + directive: ConversationDirectiveConfiguration, + ): void { // Add init function - const initFunction = initMappingTemplate(); + const initFunction = initMappingTemplate(directive); resolver.addJsFunctionToSlot('init', initFunction); // Add auth function - const authFunction = authMappingTemplate(); + const authFunction = authMappingTemplate(directive); resolver.addJsFunctionToSlot('auth', authFunction); // Add verifySessionOwner function - const verifySessionOwnerFunction = verifySessionOwnerMappingTemplate(); + const verifySessionOwnerFunction = verifySessionOwnerMappingTemplate(directive); const sessionModelName = `Conversation${capitalizedFieldName}`; const sessionModelDDBDataSourceName = getModelDataSourceNameForTypeName(ctx, sessionModelName); const conversationSessionDDBDataSource = ctx.api.host.getDataSource(sessionModelDDBDataSourceName); @@ -213,7 +229,7 @@ export class ConversationResolverGenerator { resolver.addJsFunctionToSlot('writeMessageToTable', writeMessageToTableFunction, messageDDBDataSource as any); // Add retrieveMessageHistory function - const retrieveMessageHistoryFunction = readHistoryMappingTemplate(); + const retrieveMessageHistoryFunction = readHistoryMappingTemplate(directive); resolver.addJsFunctionToSlot('retrieveMessageHistory', retrieveMessageHistoryFunction, messageDDBDataSource as any); } @@ -229,18 +245,14 @@ export class ConversationResolverGenerator { capitalizedFieldName: string, ): void { const assistantResponseResolverResourceId = ResolverResourceIDs.ResolverResourceID('Mutation', directive.responseMutationName); - const assistantResponseResolverFunction = assistantMutationResolver(); + const assistantResponseResolverFunction = assistantMutationResolver(directive); const conversationMessageDataSourceName = getModelDataSourceNameForTypeName(ctx, `ConversationMessage${capitalizedFieldName}`); const conversationMessageDataSource = ctx.api.host.getDataSource(conversationMessageDataSourceName); - - const mappingTemplate = { - codeMappingTemplate: assistantResponseResolverFunction, - }; const assistantResponseResolver = new TransformerResolver( 'Mutation', directive.responseMutationName, assistantResponseResolverResourceId, - mappingTemplate, + { codeMappingTemplate: assistantResponseResolverFunction }, [], [], conversationMessageDataSource as any, @@ -255,13 +267,17 @@ export class ConversationResolverGenerator { * @param ctx - The transformer context provider * @param capitalizedFieldName - The capitalized field name */ - private createAssistantResponseSubscriptionResolver(ctx: TransformerContextProvider, capitalizedFieldName: string): void { + private createAssistantResponseSubscriptionResolver( + ctx: TransformerContextProvider, + directive: ConversationDirectiveConfiguration, + capitalizedFieldName: string, + ): void { const onAssistantResponseSubscriptionFieldName = `onCreateAssistantResponse${capitalizedFieldName}`; const onAssistantResponseSubscriptionResolverResourceId = ResolverResourceIDs.ResolverResourceID( 'Subscription', onAssistantResponseSubscriptionFieldName, ); - const onAssistantResponseSubscriptionResolverFunction = conversationMessageSubscriptionMappingTamplate(); + const onAssistantResponseSubscriptionResolverFunction = conversationMessageSubscriptionMappingTamplate(directive); const mappingTemplate = { codeMappingTemplate: onAssistantResponseSubscriptionResolverFunction, @@ -301,6 +317,23 @@ export class ConversationResolverGenerator { return ctx.api.host.addLambdaDataSource(functionDataSourceId, referencedFunction, {}, functionDataSourceScope); } + private addPostProcessingSlotToListMessagesPipeline( + ctx: TransformerContextProvider, + directive: ConversationDirectiveConfiguration, + ): void { + const messageModelName = directive.messageModel.messageModel.name.value; + const pluralized = pluralize(messageModelName); + const listMessagesResolver = ctx.resolvers.getResolver('Query', `list${pluralized}`) as TransformerResolver; + + const listMessagePostDataLoadFunction = listMessagePostDataLoadMappingTemplate(directive); + const initResolverFn = listMessageInitMappingTemplate(directive); + + if (listMessagesResolver) { + listMessagesResolver.addJsFunctionToSlot('postDataLoad', listMessagePostDataLoadFunction); + listMessagesResolver.addJsFunctionToSlot('init', initResolverFn); + } + } + /** * Sets up the message table index * @param ctx - The transformer context provider diff --git a/packages/amplify-graphql-generation-transformer/package.json b/packages/amplify-graphql-generation-transformer/package.json index c24201bce7..17d44e43e6 100644 --- a/packages/amplify-graphql-generation-transformer/package.json +++ b/packages/amplify-graphql-generation-transformer/package.json @@ -16,9 +16,10 @@ "access": "public" }, "scripts": { - "build": "tsc", + "build": "tsc && yarn copy-js-resolver-templates", "watch": "tsc -w", "clean": "rimraf ./lib", + "copy-js-resolver-templates": "cp ./src/resolvers/*.template.js ./lib/resolvers", "test": "jest", "extract-api": "ts-node ../../scripts/extract-api.ts" }, diff --git a/packages/amplify-graphql-generation-transformer/src/__tests__/__snapshots__/amplify-graphql-generation-transformer.test.ts.snap b/packages/amplify-graphql-generation-transformer/src/__tests__/__snapshots__/amplify-graphql-generation-transformer.test.ts.snap index e9ccee8c7a..d6dbbbdc2d 100644 --- a/packages/amplify-graphql-generation-transformer/src/__tests__/__snapshots__/amplify-graphql-generation-transformer.test.ts.snap +++ b/packages/amplify-graphql-generation-transformer/src/__tests__/__snapshots__/amplify-graphql-generation-transformer.test.ts.snap @@ -56,10 +56,12 @@ export const response = (ctx) => { `; exports[`generation route all scalar types 2`] = ` -"export function request(ctx) { +{ + "makeBox-invoke-bedrock-fn": "export function request(ctx) { const toolConfig = {"tools":[{"toolSpec":{"name":"responseType","description":"Generate a response type for the given field","inputSchema":{"json":{"type":"object","properties":{"value":{"type":"object","properties":{"int":{"type":"number","description":"A signed 32-bit integer value."},"float":{"type":"number","description":"An IEEE 754 floating point value."},"string":{"type":"string","description":"A UTF-8 character sequence."},"id":{"type":"string","description":"A unique identifier for an object. This scalar is serialized like a String but isn't meant to be human-readable."},"boolean":{"type":"boolean","description":"A boolean value."},"awsjson":{"type":"string","description":"A JSON string. Any valid JSON construct is automatically parsed and loaded in the resolver code as maps, lists, or scalar values rather than as the literal input strings. Unquoted strings or otherwise invalid JSON result in a GraphQL validation error."},"awsemail":{"type":"string","description":"An email address in the format local-part@domain-part as defined by RFC 822.","pattern":"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\\\\.[a-zA-Z]{2,}$"},"awsdate":{"type":"string","description":"An extended ISO 8601 date string in the format YYYY-MM-DD.","pattern":"^\\\\d{4}-d{2}-d{2}$"},"awstime":{"type":"string","description":"An extended ISO 8601 time string in the format hh:mm:ss.sss.","pattern":"^\\\\d{2}:\\\\d{2}:\\\\d{2}\\\\.\\\\d{3}$"},"awsdatetime":{"type":"string","description":"An extended ISO 8601 date and time string in the format YYYY-MM-DDThh:mm:ss.sssZ.","pattern":"^\\\\d{4}-\\\\d{2}-\\\\d{2}T\\\\d{2}:\\\\d{2}:\\\\d{2}\\\\.\\\\d{3}Z$"},"awstimestamp":{"type":"string","description":"An integer value representing the number of seconds before or after 1970-01-01-T00:00Z.","pattern":"^\\\\d+$"},"awsphone":{"type":"string","description":"A phone number. This value is stored as a string. Phone numbers can contain either spaces or hyphens to separate digit groups. Phone numbers without a country code are assumed to be US/North American numbers adhering to the North American Numbering Plan (NANP).","pattern":"^\\\\d{3}-d{3}-d{4}$"},"awsurl":{"type":"string","description":"A URL as defined by RFC 1738. For example, https://www.amazon.com/dp/B000NZW3KC/ or mailto:example@example.com. URLs must contain a schema (http, mailto) and can't contain two forward slashes (//) in the path part.","pattern":"^(https?|mailto)://[^s/$.?#].[^s]*$"},"awsipaddress":{"type":"string","description":"A valid IPv4 or IPv6 address. IPv4 addresses are expected in quad-dotted notation (123.12.34.56). IPv6 addresses are expected in non-bracketed, colon-separated format (1a2b:3c4b::1234:4567). You can include an optional CIDR suffix (123.45.67.89/16) to indicate subnet mask."}},"required":[]}},"required":["value"]}}}}],"toolChoice":{"tool":{"name":"responseType"}}}; const prompt = ""; const args = JSON.stringify(ctx.args); + const inferenceConfig = undefined; return { resourcePath: '/model/anthropic.claude-3-haiku-20240307-v1:0/converse', @@ -67,16 +69,18 @@ exports[`generation route all scalar types 2`] = ` params: { headers: { 'Content-Type': 'application/json' }, body: { - messages: [{ - role: 'user', - content: [{ text: args }], - }], + messages: [ + { + role: 'user', + content: [{ text: args }], + }, + ], system: [{ text: prompt }], toolConfig, - // default inference config - } - } - } + ...inferenceConfig, + }, + }, + }; } export function response(ctx) { @@ -97,7 +101,9 @@ export function response(ctx) { const response = toolUse.input.value; return response; -}" +} +", +} `; exports[`generation route custom query 1`] = ` @@ -156,10 +162,12 @@ export const response = (ctx) => { `; exports[`generation route custom query 2`] = ` -"export function request(ctx) { +{ + "generateRecipe-invoke-bedrock-fn": "export function request(ctx) { const toolConfig = {"tools":[{"toolSpec":{"name":"responseType","description":"Generate a response type for the given field","inputSchema":{"json":{"type":"object","properties":{"value":{"type":"object","properties":{"name":{"type":"string","description":"A UTF-8 character sequence."},"ingredients":{"type":"array","items":{"type":"string","description":"A UTF-8 character sequence."}},"instructions":{"type":"string","description":"A UTF-8 character sequence."},"meal":{"type":"object","properties":{"Meal":{"type":"string","enum":["BREAKFAST","LUNCH","DINNER"]}},"required":[]}},"required":[]}},"required":["value"]}}}}],"toolChoice":{"tool":{"name":"responseType"}}}; const prompt = "You are a helpful assistant that generates recipes."; const args = JSON.stringify(ctx.args); + const inferenceConfig = undefined; return { resourcePath: '/model/anthropic.claude-3-haiku-20240307-v1:0/converse', @@ -167,16 +175,18 @@ exports[`generation route custom query 2`] = ` params: { headers: { 'Content-Type': 'application/json' }, body: { - messages: [{ - role: 'user', - content: [{ text: args }], - }], + messages: [ + { + role: 'user', + content: [{ text: args }], + }, + ], system: [{ text: prompt }], toolConfig, - // default inference config - } - } - } + ...inferenceConfig, + }, + }, + }; } export function response(ctx) { @@ -197,7 +207,9 @@ export function response(ctx) { const response = toolUse.input.value; return response; -}" +} +", +} `; exports[`generation route model type with null timestamps 1`] = ` @@ -260,6 +272,7 @@ exports[`generation route model type with null timestamps 2`] = ` const toolConfig = {"tools":[{"toolSpec":{"name":"responseType","description":"Generate a response type for the given field","inputSchema":{"json":{"type":"object","properties":{"value":{"type":"object","properties":{"content":{"type":"string","description":"A UTF-8 character sequence."},"isDone":{"type":"boolean","description":"A boolean value."},"id":{"type":"string","description":"A unique identifier for an object. This scalar is serialized like a String but isn't meant to be human-readable."}},"required":["id"]}},"required":["value"]}}}}],"toolChoice":{"tool":{"name":"responseType"}}}; const prompt = "Make a string based on the description."; const args = JSON.stringify(ctx.args); + const inferenceConfig = undefined; return { resourcePath: '/model/anthropic.claude-3-haiku-20240307-v1:0/converse', @@ -267,16 +280,18 @@ exports[`generation route model type with null timestamps 2`] = ` params: { headers: { 'Content-Type': 'application/json' }, body: { - messages: [{ - role: 'user', - content: [{ text: args }], - }], + messages: [ + { + role: 'user', + content: [{ text: args }], + }, + ], system: [{ text: prompt }], toolConfig, - // default inference config - } - } - } + ...inferenceConfig, + }, + }, + }; } export function response(ctx) { @@ -297,7 +312,8 @@ export function response(ctx) { const response = toolUse.input.value; return response; -}" +} +" `; exports[`generation route scalar type 1`] = ` @@ -356,10 +372,29 @@ export const response = (ctx) => { `; exports[`generation route scalar type 2`] = ` -"export function request(ctx) { +{ + "Query.makeTodo.auth.1.req.vtl": "## [Start] Field Authorization Steps. ** +#set( $isAuthorized = false ) +#if( $util.authType() == "IAM Authorization" ) + #if( !$isAuthorized ) + #if( $ctx.identity.userArn == $ctx.stash.unauthRole ) + #set( $isAuthorized = true ) + #end + #end +#end +#if( !$isAuthorized ) +$util.unauthorized() +#end +$util.toJson({"version":"2018-05-29","payload":{}}) +## [End] Field Authorization Steps. **", + "Query.makeTodo.auth.1.res.vtl": "## [Start] Return Source Field. ** +$util.toJson($context.source["makeTodo"]) +## [End] Return Source Field. **", + "makeTodo-invoke-bedrock-fn": "export function request(ctx) { const toolConfig = {"tools":[{"toolSpec":{"name":"responseType","description":"Generate a response type for the given field","inputSchema":{"json":{"type":"object","properties":{"value":{"type":"string","description":"A UTF-8 character sequence."}},"required":["value"]}}}}],"toolChoice":{"tool":{"name":"responseType"}}}; const prompt = "Make a string based on the description."; const args = JSON.stringify(ctx.args); + const inferenceConfig = undefined; return { resourcePath: '/model/anthropic.claude-3-haiku-20240307-v1:0/converse', @@ -367,16 +402,69 @@ exports[`generation route scalar type 2`] = ` params: { headers: { 'Content-Type': 'application/json' }, body: { - messages: [{ - role: 'user', - content: [{ text: args }], - }], + messages: [ + { + role: 'user', + content: [{ text: args }], + }, + ], system: [{ text: prompt }], toolConfig, - // default inference config - } - } + ...inferenceConfig, + }, + }, + }; +} + +export function response(ctx) { + if (ctx.error) { + util.error(ctx.error.message, ctx.error.type); } + const body = JSON.parse(ctx.result.body); + const { content } = body.output.message; + + if (content.length < 1) { + util.error('No content block in assistant response.', 'error'); + } + + const toolUse = content[0].toolUse; + if (!toolUse) { + util.error('Missing tool use block in assistant response.', 'error'); + } + + const response = toolUse.input.value; + return response; +} +", +} +`; + +exports[`generation route with valid inference configuration 1`] = ` +{ + "generateWithConfig-invoke-bedrock-fn": "export function request(ctx) { + const toolConfig = {"tools":[{"toolSpec":{"name":"responseType","description":"Generate a response type for the given field","inputSchema":{"json":{"type":"object","properties":{"value":{"type":"string","description":"A UTF-8 character sequence."}},"required":["value"]}}}}],"toolChoice":{"tool":{"name":"responseType"}}}; + const prompt = "Generate a string based on the description."; + const args = JSON.stringify(ctx.args); + const inferenceConfig = { inferenceConfig: {"maxTokens":100,"temperature":0.7,"topP":0.9} },; + + return { + resourcePath: '/model/anthropic.claude-3-haiku-20240307-v1:0/converse', + method: 'POST', + params: { + headers: { 'Content-Type': 'application/json' }, + body: { + messages: [ + { + role: 'user', + content: [{ text: args }], + }, + ], + system: [{ text: prompt }], + toolConfig, + ...inferenceConfig, + }, + }, + }; } export function response(ctx) { @@ -397,5 +485,7 @@ export function response(ctx) { const response = toolUse.input.value; return response; -}" +} +", +} `; diff --git a/packages/amplify-graphql-generation-transformer/src/__tests__/amplify-graphql-generation-transformer.test.ts b/packages/amplify-graphql-generation-transformer/src/__tests__/amplify-graphql-generation-transformer.test.ts index 14b33f7940..53cb7b172d 100644 --- a/packages/amplify-graphql-generation-transformer/src/__tests__/amplify-graphql-generation-transformer.test.ts +++ b/packages/amplify-graphql-generation-transformer/src/__tests__/amplify-graphql-generation-transformer.test.ts @@ -53,9 +53,9 @@ test('generation route scalar type', () => { expect(resolverCode).toBeDefined(); expect(resolverCode).toMatchSnapshot(); - const resolverFnCode = getResolverFnResource(queryName, out.rootStack.Resources)['Properties']['Code']; - expect(resolverFnCode).toBeDefined(); - expect(resolverFnCode).toMatchSnapshot(); + const resolvers = out.resolvers; + expect(resolvers).toBeDefined(); + expect(resolvers).toMatchSnapshot(); const schema = parse(out.schema); validateModelSchema(schema); @@ -92,9 +92,9 @@ test('generation route custom query', () => { expect(resolverCode).toBeDefined(); expect(resolverCode).toMatchSnapshot(); - const resolverFnCode = getResolverFnResource(queryName, out.rootStack.Resources)['Properties']['Code']; - expect(resolverFnCode).toBeDefined(); - expect(resolverFnCode).toMatchSnapshot(); + const resolvers = out.resolvers; + expect(resolvers).toBeDefined(); + expect(resolvers).toMatchSnapshot(); const schema = parse(out.schema); validateModelSchema(schema); @@ -124,9 +124,9 @@ test('generation route model type with null timestamps', () => { expect(resolverCode).toBeDefined(); expect(resolverCode).toMatchSnapshot(); - const resolverFnCode = getResolverFnResource(queryName, out.rootStack.Resources)['Properties']['Code']; - expect(resolverFnCode).toBeDefined(); - expect(resolverFnCode).toMatchSnapshot(); + const resolverFn = out.resolvers['makeTodo-invoke-bedrock-fn']; + expect(resolverFn).toBeDefined(); + expect(resolverFn).toMatchSnapshot(); const schema = parse(out.schema); validateModelSchema(schema); @@ -226,9 +226,37 @@ test('generation route all scalar types', () => { expect(resolverCode).toBeDefined(); expect(resolverCode).toMatchSnapshot(); - const resolverFnCode = getResolverFnResource(queryName, out.rootStack.Resources)['Properties']['Code']; - expect(resolverFnCode).toBeDefined(); - expect(resolverFnCode).toMatchSnapshot(); + const resolvers = out.resolvers; + expect(resolvers).toBeDefined(); + expect(resolvers).toMatchSnapshot(); + + const schema = parse(out.schema); + validateModelSchema(schema); +}); + +test('generation route with valid inference configuration', () => { + const queryName = 'generateWithConfig'; + const inputSchema = ` + type Query { + ${queryName}(description: String!): String + @generation( + aiModel: "anthropic.claude-3-haiku-20240307-v1:0", + systemPrompt: "Generate a string based on the description.", + inferenceConfiguration: { + maxTokens: 100, + temperature: 0.7, + topP: 0.9 + } + ) + } + `; + + const out = transform(inputSchema); + expect(out).toBeDefined(); + + const resolvers = out.resolvers; + expect(resolvers).toBeDefined(); + expect(resolvers).toMatchSnapshot(); const schema = parse(out.schema); validateModelSchema(schema); @@ -294,22 +322,6 @@ const getResolverResource = (queryName: string, resources?: Record) return resources?.[resolverName]; }; -const getResolverFnResource = (queryName: string, resources?: Record): Record => { - const capitalizedQueryName = queryName.charAt(0).toUpperCase() + queryName.slice(1); - const resourcePrefix = `Query${capitalizedQueryName}DataResolverFn`; - if (!resources) { - fail('No resources found.'); - } - const resource = Object.entries(resources).find(([key, _]) => { - return key.startsWith(resourcePrefix); - })?.[1]; - - if (!resource) { - fail(`Resource named with prefix ${resourcePrefix} not found.`); - } - return resource; -}; - const defaultAuthConfig: AppSyncAuthConfiguration = { defaultAuthentication: { authenticationType: 'AWS_IAM', diff --git a/packages/amplify-graphql-generation-transformer/src/grapqhl-generation-transformer.ts b/packages/amplify-graphql-generation-transformer/src/grapqhl-generation-transformer.ts index 4114677112..71d6420263 100644 --- a/packages/amplify-graphql-generation-transformer/src/grapqhl-generation-transformer.ts +++ b/packages/amplify-graphql-generation-transformer/src/grapqhl-generation-transformer.ts @@ -122,14 +122,11 @@ export class GenerationTransformer extends TransformerPluginBase { invokeBedrockFunction: MappingTemplateProvider, dataSource: cdk.aws_appsync.HttpDataSource, ): void { - const mappingTemplate = { - codeMappingTemplate: invokeBedrockFunction, - }; const conversationPipelineResolver = new TransformerResolver( parentName, fieldName, resolverResourceId, - mappingTemplate, + { codeMappingTemplate: invokeBedrockFunction }, ['auth'], [], dataSource as any, diff --git a/packages/amplify-graphql-generation-transformer/src/resolvers/invoke-bedrock-resolver-fn.template.js b/packages/amplify-graphql-generation-transformer/src/resolvers/invoke-bedrock-resolver-fn.template.js new file mode 100644 index 0000000000..bd48071ed6 --- /dev/null +++ b/packages/amplify-graphql-generation-transformer/src/resolvers/invoke-bedrock-resolver-fn.template.js @@ -0,0 +1,45 @@ +export function request(ctx) { + const toolConfig = [[TOOL_CONFIG]]; + const prompt = [[SYSTEM_PROMPT]]; + const args = JSON.stringify(ctx.args); + const inferenceConfig = [[INFERENCE_CONFIG]]; + + return { + resourcePath: '/model/[[AI_MODEL]]/converse', + method: 'POST', + params: { + headers: { 'Content-Type': 'application/json' }, + body: { + messages: [ + { + role: 'user', + content: [{ text: args }], + }, + ], + system: [{ text: prompt }], + toolConfig, + ...inferenceConfig, + }, + }, + }; +} + +export function response(ctx) { + if (ctx.error) { + util.error(ctx.error.message, ctx.error.type); + } + const body = JSON.parse(ctx.result.body); + const { content } = body.output.message; + + if (content.length < 1) { + util.error('No content block in assistant response.', 'error'); + } + + const toolUse = content[0].toolUse; + if (!toolUse) { + util.error('Missing tool use block in assistant response.', 'error'); + } + + const response = toolUse.input.value; + return response; +} diff --git a/packages/amplify-graphql-generation-transformer/src/resolvers/invoke-bedrock.ts b/packages/amplify-graphql-generation-transformer/src/resolvers/invoke-bedrock.ts index 02a5134a40..0a56488511 100644 --- a/packages/amplify-graphql-generation-transformer/src/resolvers/invoke-bedrock.ts +++ b/packages/amplify-graphql-generation-transformer/src/resolvers/invoke-bedrock.ts @@ -1,7 +1,8 @@ import { MappingTemplate } from '@aws-amplify/graphql-transformer-core'; import { MappingTemplateProvider } from '@aws-amplify/graphql-transformer-interfaces'; -import { dedent } from 'ts-dedent'; import { GenerationConfigurationWithToolConfig, InferenceConfiguration } from '../grapqhl-generation-transformer'; +import fs from 'fs'; +import path from 'path'; /** * Creates the resolver functions for invoking Amazon Bedrock. @@ -9,82 +10,31 @@ import { GenerationConfigurationWithToolConfig, InferenceConfiguration } from '. * @param {GenerationConfigurationWithToolConfig} config - The configuration object containing AI model details, tool config, and inference settings. * @returns {Object} An object containing request and response resolver functions. */ - export const createInvokeBedrockResolverFunction = (config: GenerationConfigurationWithToolConfig): MappingTemplateProvider => { - const req = createInvokeBedrockRequestFunction(config); - const res = createInvokeBedrockResponseFunction(); - return MappingTemplate.inlineTemplateFromString(dedent(req + '\n' + res)); -}; - -/** - * Creates the request function for the Bedrock resolver. - * - * @param {GenerationConfigurationWithToolConfig} config - The configuration object for the resolver. - * @returns {MappingTemplateProvider} A MappingTemplateProvider for the request function. - */ -const createInvokeBedrockRequestFunction = (config: GenerationConfigurationWithToolConfig): string => { - const { aiModel, toolConfig, inferenceConfiguration } = config; - const stringifiedToolConfig = JSON.stringify(toolConfig); - const stringifiedSystemPrompt = JSON.stringify(config.systemPrompt); - // TODO: add stopReason: max_tokens error handling - const inferenceConfig = getInferenceConfigResolverDefinition(inferenceConfiguration); - const requestFunctionString = ` - export function request(ctx) { - const toolConfig = ${stringifiedToolConfig}; - const prompt = ${stringifiedSystemPrompt}; - const args = JSON.stringify(ctx.args); - - return { - resourcePath: '/model/${aiModel}/converse', - method: 'POST', - params: { - headers: { 'Content-Type': 'application/json' }, - body: { - messages: [{ - role: 'user', - content: [{ text: args }], - }], - system: [{ text: prompt }], - toolConfig, - ${inferenceConfig} - } - } - } - }`; - - return requestFunctionString; + const { aiModel, toolConfig, inferenceConfiguration, field } = config; + const AI_MODEL = aiModel; + const TOOL_CONFIG = JSON.stringify(toolConfig); + const SYSTEM_PROMPT = JSON.stringify(config.systemPrompt); + const INFERENCE_CONFIG = getInferenceConfigResolverDefinition(inferenceConfiguration); + + const resolver = generateResolver('invoke-bedrock-resolver-fn.template.js', { + AI_MODEL, + TOOL_CONFIG, + SYSTEM_PROMPT, + INFERENCE_CONFIG, + }); + + const templateName = `${field.name.value}-invoke-bedrock-fn`; + return MappingTemplate.s3MappingFunctionCodeFromString(resolver, templateName); }; -/** - * Creates the response function for the Bedrock resolver. - * - * @returns {MappingTemplateProvider} A MappingTemplateProvider for the response function. - */ -const createInvokeBedrockResponseFunction = (): string => { - // TODO: add stopReason: max_tokens error handling - const responseFunctionString = ` - export function response(ctx) { - if (ctx.error) { - util.error(ctx.error.message, ctx.error.type); - } - const body = JSON.parse(ctx.result.body); - const { content } = body.output.message; - - if (content.length < 1) { - util.error('No content block in assistant response.', 'error'); - } - - const toolUse = content[0].toolUse; - if (!toolUse) { - util.error('Missing tool use block in assistant response.', 'error'); - } - - const response = toolUse.input.value; - return response; - } -`; - - return responseFunctionString; +const generateResolver = (fileName: string, values: Record): string => { + let resolver = fs.readFileSync(path.join(__dirname, fileName), 'utf8'); + Object.entries(values).forEach(([key, value]) => { + const replaced = resolver.replace(new RegExp(`\\[\\[${key}\\]\\]`, 'g'), value); + resolver = replaced; + }); + return resolver; }; /** @@ -95,6 +45,6 @@ const createInvokeBedrockResponseFunction = (): string => { */ const getInferenceConfigResolverDefinition = (inferenceConfiguration?: InferenceConfiguration): string => { return inferenceConfiguration && Object.keys(inferenceConfiguration).length > 0 - ? `inferenceConfig: ${JSON.stringify(inferenceConfiguration)},` - : '// default inference config'; + ? `{ inferenceConfig: ${JSON.stringify(inferenceConfiguration)} },` + : 'undefined'; }; diff --git a/packages/amplify-graphql-model-transformer/src/__tests__/__snapshots__/model-transformer.test.ts.snap b/packages/amplify-graphql-model-transformer/src/__tests__/__snapshots__/model-transformer.test.ts.snap index c676b2fcd3..0115ca1500 100644 --- a/packages/amplify-graphql-model-transformer/src/__tests__/__snapshots__/model-transformer.test.ts.snap +++ b/packages/amplify-graphql-model-transformer/src/__tests__/__snapshots__/model-transformer.test.ts.snap @@ -2579,7 +2579,7 @@ $util.toJson({}) $util.qr($ListRequest.put(\\"operation\\", \\"Scan\\")) #end #if( !$util.isNull($ctx.stash.metadata.index) ) - #set( $ListRequest.IndexName = $ctx.stash.metadata.index ) + #set( $ListRequest.index = $ctx.stash.metadata.index ) #end $util.toJson($ListRequest) ## [End] List Request. **", @@ -3166,7 +3166,7 @@ $util.toJson({}) $util.qr($ListRequest.put(\\"operation\\", \\"Scan\\")) #end #if( !$util.isNull($ctx.stash.metadata.index) ) - #set( $ListRequest.IndexName = $ctx.stash.metadata.index ) + #set( $ListRequest.index = $ctx.stash.metadata.index ) #end $util.toJson($ListRequest) ## [End] List Request. **", @@ -3753,7 +3753,7 @@ $util.toJson({}) $util.qr($ListRequest.put(\\"operation\\", \\"Scan\\")) #end #if( !$util.isNull($ctx.stash.metadata.index) ) - #set( $ListRequest.IndexName = $ctx.stash.metadata.index ) + #set( $ListRequest.index = $ctx.stash.metadata.index ) #end $util.toJson($ListRequest) ## [End] List Request. **", diff --git a/packages/amplify-graphql-model-transformer/src/resolvers/dynamodb/query.ts b/packages/amplify-graphql-model-transformer/src/resolvers/dynamodb/query.ts index 981281c39e..fc153e6894 100644 --- a/packages/amplify-graphql-model-transformer/src/resolvers/dynamodb/query.ts +++ b/packages/amplify-graphql-model-transformer/src/resolvers/dynamodb/query.ts @@ -160,7 +160,7 @@ export const generateListRequestTemplate = (): string => { ]), qref(methodCall(ref(`${requestVariable}.put`), str('operation'), str('Scan'))), ), - iff(not(methodCall(ref('util.isNull'), ref(indexNameVariable))), set(ref(`${requestVariable}.IndexName`), ref(indexNameVariable))), + iff(not(methodCall(ref('util.isNull'), ref(indexNameVariable))), set(ref(`${requestVariable}.index`), ref(indexNameVariable))), toJson(ref(requestVariable)), ]); return printBlock('List Request')(expression); diff --git a/yarn.lock b/yarn.lock index d2f12b7977..3b14bfe9f8 100644 --- a/yarn.lock +++ b/yarn.lock @@ -10,10 +10,10 @@ "@jridgewell/gen-mapping" "^0.3.5" "@jridgewell/trace-mapping" "^0.3.24" -"@aws-amplify/ai-constructs@^0.1.4": - version "0.1.4" - resolved "https://registry.npmjs.org/@aws-amplify/ai-constructs/-/ai-constructs-0.1.4.tgz#043ca7793cb4a97ad7864797bd70dbfa323329f4" - integrity sha512-BGLBFs/pt6JrNgUo+QD0Szt/ssHMa6EyEE45yLoHemwPHRuJPpnFmxIbbxgxaqJP0mWK6QMs9Wh3IsdJ/6XhDA== +"@aws-amplify/ai-constructs@0.0.0-test-20240925144932": + version "0.0.0-test-20240925144932" + resolved "https://registry.yarnpkg.com/@aws-amplify/ai-constructs/-/ai-constructs-0.0.0-test-20240925144932.tgz#4b7fd94e08f18f0ccc96f44b02d30bd87b470647" + integrity sha512-F63OomSxZVtY3WY4akD2XYXKMSda+jBlC6/B3IAf9F9Vi1+0RSn/iwQnRwlIMRby3X7r9z0bVk7h0ey1AkH42g== dependencies: "@aws-amplify/plugin-types" "^1.0.1" "@aws-sdk/client-bedrock-runtime" "^3.622.0"