Skip to content

Feature/Add Neo4j GraphRag support #3686

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Dec 23, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions packages/components/credentials/Neo4jApi.credential.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import { INodeParams, INodeCredential } from '../src/Interface'

class Neo4jApi implements INodeCredential {
label: string
name: string
version: number
description: string
inputs: INodeParams[]

constructor() {
this.label = 'Neo4j API'
this.name = 'neo4jApi'
this.version = 1.0
this.description =
'Refer to <a target="_blank" href="https://neo4j.com/docs/operations-manual/current/authentication-authorization/">official guide</a> on Neo4j authentication'
this.inputs = [
{
label: 'Neo4j URL',
name: 'url',
type: 'string',
description: 'Your Neo4j instance URL (e.g., neo4j://localhost:7687)'
},
{
label: 'Username',
name: 'username',
type: 'string',
description: 'Neo4j database username'
},
{
label: 'Password',
name: 'password',
type: 'password',
description: 'Neo4j database password'
}
]
}
}

module.exports = { credClass: Neo4jApi }
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
import { ICommonObject, INode, INodeData, INodeParams, INodeOutputsValue, IServerSideEventStreamer } from '../../../src/Interface'
import { FromLLMInput, GraphCypherQAChain } from '@langchain/community/chains/graph_qa/cypher'
import { getBaseClasses } from '../../../src/utils'
import { BasePromptTemplate, PromptTemplate, FewShotPromptTemplate } from '@langchain/core/prompts'
import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler'
import { ConsoleCallbackHandler as LCConsoleCallbackHandler } from '@langchain/core/tracers/console'
import { checkInputs, Moderation, streamResponse } from '../../moderation/Moderation'
import { formatResponse } from '../../outputparsers/OutputParserHelpers'

class GraphCypherQA_Chain implements INode {
label: string
name: string
version: number
type: string
icon: string
category: string
description: string
baseClasses: string[]
inputs: INodeParams[]
sessionId?: string
outputs: INodeOutputsValue[]

constructor(fields?: { sessionId?: string }) {
this.label = 'Graph Cypher QA Chain'
this.name = 'graphCypherQAChain'
this.version = 1.0
this.type = 'GraphCypherQAChain'
this.icon = 'graphqa.svg'
this.category = 'Chains'
this.description = 'Advanced chain for question-answering against a Neo4j graph by generating Cypher statements'
this.baseClasses = [this.type, ...getBaseClasses(GraphCypherQAChain)]
this.sessionId = fields?.sessionId
this.inputs = [
{
label: 'Language Model',
name: 'model',
type: 'BaseLanguageModel',
description: 'Model for generating Cypher queries and answers.'
},
{
label: 'Neo4j Graph',
name: 'graph',
type: 'Neo4j'
},
{
label: 'Cypher Generation Prompt',
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can have all of these prompts as optional.

in future, we will be adding conditional input

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice, Cypher Generation Prompt now will be optional

name: 'cypherPrompt',
optional: true,
type: 'BasePromptTemplate',
description: 'Prompt template for generating Cypher queries. Must include {schema} and {question} variables'
},
{
label: 'Cypher Generation Model',
name: 'cypherModel',
optional: true,
type: 'BaseLanguageModel',
description: 'Model for generating Cypher queries. If not provided, the main model will be used.'
},
{
label: 'QA Prompt',
name: 'qaPrompt',
optional: true,
type: 'BasePromptTemplate',
description: 'Prompt template for generating answers. Must include {context} and {question} variables'
},
{
label: 'QA Model',
name: 'qaModel',
optional: true,
type: 'BaseLanguageModel',
description: 'Model for generating answers. If not provided, the main model will be used.'
},
{
label: 'Input Moderation',
description: 'Detect text that could generate harmful output and prevent it from being sent to the language model',
name: 'inputModeration',
type: 'Moderation',
optional: true,
list: true
},
{
label: 'Return Direct',
name: 'returnDirect',
type: 'boolean',
default: false,
optional: true,
description: 'If true, return the raw query results instead of using the QA chain'
}
]
this.outputs = [
{
label: 'Graph Cypher QA Chain',
name: 'graphCypherQAChain',
baseClasses: [this.type, ...getBaseClasses(GraphCypherQAChain)]
},
{
label: 'Output Prediction',
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here we allow this chain to output a string or json, but we dont have any function that checks for the output

You can reference LLMChain

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I took RetrievalQAChain as a reference and removed Output Prediction

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added it back. Works now!

name: 'outputPrediction',
baseClasses: ['string', 'json']
}
]
}

async init(nodeData: INodeData, input: string, options: ICommonObject): Promise<any> {
const model = nodeData.inputs?.model
const cypherModel = nodeData.inputs?.cypherModel
const qaModel = nodeData.inputs?.qaModel
const graph = nodeData.inputs?.graph
const cypherPrompt = nodeData.inputs?.cypherPrompt as BasePromptTemplate | FewShotPromptTemplate | undefined
const qaPrompt = nodeData.inputs?.qaPrompt as BasePromptTemplate | undefined
const returnDirect = nodeData.inputs?.returnDirect as boolean
const output = nodeData.outputs?.output as string

// Handle prompt values if they exist
let cypherPromptTemplate: PromptTemplate | FewShotPromptTemplate | undefined
let qaPromptTemplate: PromptTemplate | undefined

if (cypherPrompt) {
if (cypherPrompt instanceof PromptTemplate) {
cypherPromptTemplate = new PromptTemplate({
template: cypherPrompt.template as string,
inputVariables: cypherPrompt.inputVariables
})
if (!qaPrompt) {
throw new Error('QA Prompt is required when Cypher Prompt is a Prompt Template')
}
} else if (cypherPrompt instanceof FewShotPromptTemplate) {
const examplePrompt = cypherPrompt.examplePrompt as PromptTemplate
cypherPromptTemplate = new FewShotPromptTemplate({
examples: cypherPrompt.examples,
examplePrompt: examplePrompt,
inputVariables: cypherPrompt.inputVariables,
prefix: cypherPrompt.prefix,
suffix: cypherPrompt.suffix,
exampleSeparator: cypherPrompt.exampleSeparator,
templateFormat: cypherPrompt.templateFormat
})
} else {
cypherPromptTemplate = cypherPrompt as PromptTemplate
}
}

if (qaPrompt instanceof PromptTemplate) {
qaPromptTemplate = new PromptTemplate({
template: qaPrompt.template as string,
inputVariables: qaPrompt.inputVariables
})
}

if ((!cypherModel || !qaModel) && !model) {
throw new Error('Language Model is required when Cypher Model or QA Model are not provided')
}

// Validate required variables in prompts
if (
cypherPromptTemplate &&
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

here we check if cypherPromptTemplate is not undefined first

(!cypherPromptTemplate?.inputVariables.includes('schema') || !cypherPromptTemplate?.inputVariables.includes('question'))
) {
throw new Error('Cypher Generation Prompt must include {schema} and {question} variables')
}

const fromLLMInput: FromLLMInput = {
llm: model,
graph,
returnDirect
}

if (cypherModel && cypherPromptTemplate) {
fromLLMInput['cypherLLM'] = cypherModel
fromLLMInput['cypherPrompt'] = cypherPromptTemplate
}

if (qaModel && qaPromptTemplate) {
fromLLMInput['qaLLM'] = qaModel
fromLLMInput['qaPrompt'] = qaPromptTemplate
}

const chain = GraphCypherQAChain.fromLLM(fromLLMInput)

if (output === this.name) {
return chain
} else if (output === 'outputPrediction') {
nodeData.instance = chain
return await this.run(nodeData, input, options)
}

return chain
}

async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | object> {
const chain = nodeData.instance as GraphCypherQAChain
const moderations = nodeData.inputs?.inputModeration as Moderation[]
const returnDirect = nodeData.inputs?.returnDirect as boolean

const shouldStreamResponse = options.shouldStreamResponse
const sseStreamer: IServerSideEventStreamer = options.sseStreamer as IServerSideEventStreamer
const chatId = options.chatId

// Handle input moderation if configured
if (moderations && moderations.length > 0) {
try {
input = await checkInputs(moderations, input)
} catch (e) {
await new Promise((resolve) => setTimeout(resolve, 500))
if (shouldStreamResponse) {
streamResponse(sseStreamer, chatId, e.message)
}
return formatResponse(e.message)
}
}

const obj = {
query: input
}

const loggerHandler = new ConsoleCallbackHandler(options.logger)
const callbackHandlers = await additionalCallbacks(nodeData, options)
let callbacks = [loggerHandler, ...callbackHandlers]

if (process.env.DEBUG === 'true') {
callbacks.push(new LCConsoleCallbackHandler())
}

try {
let response
if (shouldStreamResponse) {
if (returnDirect) {
response = await chain.invoke(obj, { callbacks })
let result = response?.result
if (typeof result === 'object') {
result = '```json\n' + JSON.stringify(result, null, 2)
}
if (result && typeof result === 'string') {
streamResponse(sseStreamer, chatId, result)
}
} else {
const handler = new CustomChainHandler(sseStreamer, chatId, 2)
callbacks.push(handler)
response = await chain.invoke(obj, { callbacks })
}
} else {
response = await chain.invoke(obj, { callbacks })
}

return formatResponse(response?.result)
} catch (error) {
console.error('Error in GraphCypherQAChain:', error)
if (shouldStreamResponse) {
streamResponse(sseStreamer, chatId, error.message)
}
return formatResponse(`Error: ${error.message}`)
}
}
}

module.exports = { nodeClass: GraphCypherQA_Chain }
22 changes: 22 additions & 0 deletions packages/components/nodes/chains/GraphCypherQAChain/graphqa.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Loading