-
-
Notifications
You must be signed in to change notification settings - Fork 19.6k
Feature/Add Neo4j GraphRag support #3686
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 11 commits
a45b035
2489d45
441b5c5
d48d4c3
d8decef
0ea4784
2f3f3ad
0f70840
e35aaf0
287849a
4ed5dd4
0a9c041
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import { INodeParams, INodeCredential } from '../src/Interface' | ||
|
||
class Neo4jApi implements INodeCredential { | ||
label: string | ||
name: string | ||
version: number | ||
description: string | ||
inputs: INodeParams[] | ||
|
||
constructor() { | ||
this.label = 'Neo4j API' | ||
this.name = 'neo4jApi' | ||
this.version = 1.0 | ||
this.description = | ||
'Refer to <a target="_blank" href="https://neo4j.com/docs/operations-manual/current/authentication-authorization/">official guide</a> on Neo4j authentication' | ||
this.inputs = [ | ||
{ | ||
label: 'Neo4j URL', | ||
name: 'url', | ||
type: 'string', | ||
description: 'Your Neo4j instance URL (e.g., neo4j://localhost:7687)' | ||
}, | ||
{ | ||
label: 'Username', | ||
name: 'username', | ||
type: 'string', | ||
description: 'Neo4j database username' | ||
}, | ||
{ | ||
label: 'Password', | ||
name: 'password', | ||
type: 'password', | ||
description: 'Neo4j database password' | ||
} | ||
] | ||
} | ||
} | ||
|
||
module.exports = { credClass: Neo4jApi } |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,256 @@ | ||
import { ICommonObject, INode, INodeData, INodeParams, INodeOutputsValue, IServerSideEventStreamer } from '../../../src/Interface' | ||
import { FromLLMInput, GraphCypherQAChain } from '@langchain/community/chains/graph_qa/cypher' | ||
import { getBaseClasses } from '../../../src/utils' | ||
import { BasePromptTemplate, PromptTemplate, FewShotPromptTemplate } from '@langchain/core/prompts' | ||
import { ConsoleCallbackHandler, CustomChainHandler, additionalCallbacks } from '../../../src/handler' | ||
import { ConsoleCallbackHandler as LCConsoleCallbackHandler } from '@langchain/core/tracers/console' | ||
import { checkInputs, Moderation, streamResponse } from '../../moderation/Moderation' | ||
import { formatResponse } from '../../outputparsers/OutputParserHelpers' | ||
|
||
class GraphCypherQA_Chain implements INode { | ||
label: string | ||
name: string | ||
version: number | ||
type: string | ||
icon: string | ||
category: string | ||
description: string | ||
baseClasses: string[] | ||
inputs: INodeParams[] | ||
sessionId?: string | ||
outputs: INodeOutputsValue[] | ||
|
||
constructor(fields?: { sessionId?: string }) { | ||
this.label = 'Graph Cypher QA Chain' | ||
this.name = 'graphCypherQAChain' | ||
this.version = 1.0 | ||
this.type = 'GraphCypherQAChain' | ||
this.icon = 'graphqa.svg' | ||
this.category = 'Chains' | ||
this.description = 'Advanced chain for question-answering against a Neo4j graph by generating Cypher statements' | ||
this.baseClasses = [this.type, ...getBaseClasses(GraphCypherQAChain)] | ||
this.sessionId = fields?.sessionId | ||
this.inputs = [ | ||
{ | ||
label: 'Language Model', | ||
name: 'model', | ||
type: 'BaseLanguageModel', | ||
description: 'Model for generating Cypher queries and answers.' | ||
}, | ||
{ | ||
label: 'Neo4j Graph', | ||
name: 'graph', | ||
type: 'Neo4j' | ||
}, | ||
{ | ||
label: 'Cypher Generation Prompt', | ||
name: 'cypherPrompt', | ||
optional: true, | ||
type: 'BasePromptTemplate', | ||
description: 'Prompt template for generating Cypher queries. Must include {schema} and {question} variables' | ||
}, | ||
{ | ||
label: 'Cypher Generation Model', | ||
name: 'cypherModel', | ||
optional: true, | ||
type: 'BaseLanguageModel', | ||
description: 'Model for generating Cypher queries. If not provided, the main model will be used.' | ||
}, | ||
{ | ||
label: 'QA Prompt', | ||
name: 'qaPrompt', | ||
optional: true, | ||
type: 'BasePromptTemplate', | ||
description: 'Prompt template for generating answers. Must include {context} and {question} variables' | ||
}, | ||
{ | ||
label: 'QA Model', | ||
name: 'qaModel', | ||
optional: true, | ||
type: 'BaseLanguageModel', | ||
description: 'Model for generating answers. If not provided, the main model will be used.' | ||
}, | ||
{ | ||
label: 'Input Moderation', | ||
description: 'Detect text that could generate harmful output and prevent it from being sent to the language model', | ||
name: 'inputModeration', | ||
type: 'Moderation', | ||
optional: true, | ||
list: true | ||
}, | ||
{ | ||
label: 'Return Direct', | ||
name: 'returnDirect', | ||
type: 'boolean', | ||
default: false, | ||
optional: true, | ||
description: 'If true, return the raw query results instead of using the QA chain' | ||
} | ||
] | ||
this.outputs = [ | ||
{ | ||
label: 'Graph Cypher QA Chain', | ||
name: 'graphCypherQAChain', | ||
baseClasses: [this.type, ...getBaseClasses(GraphCypherQAChain)] | ||
}, | ||
{ | ||
label: 'Output Prediction', | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. here we allow this chain to output a string or json, but we dont have any function that checks for the output You can reference LLMChain There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I took RetrievalQAChain as a reference and removed Output Prediction There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have added it back. Works now! |
||
name: 'outputPrediction', | ||
baseClasses: ['string', 'json'] | ||
} | ||
] | ||
} | ||
|
||
async init(nodeData: INodeData, input: string, options: ICommonObject): Promise<any> { | ||
const model = nodeData.inputs?.model | ||
const cypherModel = nodeData.inputs?.cypherModel | ||
const qaModel = nodeData.inputs?.qaModel | ||
const graph = nodeData.inputs?.graph | ||
const cypherPrompt = nodeData.inputs?.cypherPrompt as BasePromptTemplate | FewShotPromptTemplate | undefined | ||
const qaPrompt = nodeData.inputs?.qaPrompt as BasePromptTemplate | undefined | ||
const returnDirect = nodeData.inputs?.returnDirect as boolean | ||
const output = nodeData.outputs?.output as string | ||
|
||
// Handle prompt values if they exist | ||
let cypherPromptTemplate: PromptTemplate | FewShotPromptTemplate | undefined | ||
let qaPromptTemplate: PromptTemplate | undefined | ||
|
||
if (cypherPrompt) { | ||
if (cypherPrompt instanceof PromptTemplate) { | ||
cypherPromptTemplate = new PromptTemplate({ | ||
template: cypherPrompt.template as string, | ||
inputVariables: cypherPrompt.inputVariables | ||
}) | ||
if (!qaPrompt) { | ||
throw new Error('QA Prompt is required when Cypher Prompt is a Prompt Template') | ||
} | ||
} else if (cypherPrompt instanceof FewShotPromptTemplate) { | ||
const examplePrompt = cypherPrompt.examplePrompt as PromptTemplate | ||
cypherPromptTemplate = new FewShotPromptTemplate({ | ||
examples: cypherPrompt.examples, | ||
examplePrompt: examplePrompt, | ||
inputVariables: cypherPrompt.inputVariables, | ||
prefix: cypherPrompt.prefix, | ||
suffix: cypherPrompt.suffix, | ||
exampleSeparator: cypherPrompt.exampleSeparator, | ||
templateFormat: cypherPrompt.templateFormat | ||
}) | ||
} else { | ||
cypherPromptTemplate = cypherPrompt as PromptTemplate | ||
} | ||
} | ||
|
||
if (qaPrompt instanceof PromptTemplate) { | ||
qaPromptTemplate = new PromptTemplate({ | ||
template: qaPrompt.template as string, | ||
inputVariables: qaPrompt.inputVariables | ||
}) | ||
} | ||
|
||
if ((!cypherModel || !qaModel) && !model) { | ||
throw new Error('Language Model is required when Cypher Model or QA Model are not provided') | ||
} | ||
|
||
// Validate required variables in prompts | ||
if ( | ||
cypherPromptTemplate && | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. here we check if |
||
(!cypherPromptTemplate?.inputVariables.includes('schema') || !cypherPromptTemplate?.inputVariables.includes('question')) | ||
) { | ||
throw new Error('Cypher Generation Prompt must include {schema} and {question} variables') | ||
} | ||
|
||
const fromLLMInput: FromLLMInput = { | ||
llm: model, | ||
graph, | ||
returnDirect | ||
} | ||
|
||
if (cypherModel && cypherPromptTemplate) { | ||
fromLLMInput['cypherLLM'] = cypherModel | ||
fromLLMInput['cypherPrompt'] = cypherPromptTemplate | ||
} | ||
|
||
if (qaModel && qaPromptTemplate) { | ||
fromLLMInput['qaLLM'] = qaModel | ||
fromLLMInput['qaPrompt'] = qaPromptTemplate | ||
} | ||
|
||
const chain = GraphCypherQAChain.fromLLM(fromLLMInput) | ||
|
||
if (output === this.name) { | ||
return chain | ||
} else if (output === 'outputPrediction') { | ||
nodeData.instance = chain | ||
return await this.run(nodeData, input, options) | ||
} | ||
|
||
return chain | ||
} | ||
|
||
async run(nodeData: INodeData, input: string, options: ICommonObject): Promise<string | object> { | ||
const chain = nodeData.instance as GraphCypherQAChain | ||
const moderations = nodeData.inputs?.inputModeration as Moderation[] | ||
const returnDirect = nodeData.inputs?.returnDirect as boolean | ||
|
||
const shouldStreamResponse = options.shouldStreamResponse | ||
const sseStreamer: IServerSideEventStreamer = options.sseStreamer as IServerSideEventStreamer | ||
const chatId = options.chatId | ||
|
||
// Handle input moderation if configured | ||
if (moderations && moderations.length > 0) { | ||
try { | ||
input = await checkInputs(moderations, input) | ||
} catch (e) { | ||
await new Promise((resolve) => setTimeout(resolve, 500)) | ||
if (shouldStreamResponse) { | ||
streamResponse(sseStreamer, chatId, e.message) | ||
} | ||
return formatResponse(e.message) | ||
} | ||
} | ||
|
||
const obj = { | ||
query: input | ||
} | ||
|
||
const loggerHandler = new ConsoleCallbackHandler(options.logger) | ||
const callbackHandlers = await additionalCallbacks(nodeData, options) | ||
let callbacks = [loggerHandler, ...callbackHandlers] | ||
|
||
if (process.env.DEBUG === 'true') { | ||
callbacks.push(new LCConsoleCallbackHandler()) | ||
} | ||
|
||
try { | ||
let response | ||
if (shouldStreamResponse) { | ||
if (returnDirect) { | ||
response = await chain.invoke(obj, { callbacks }) | ||
let result = response?.result | ||
if (typeof result === 'object') { | ||
result = '```json\n' + JSON.stringify(result, null, 2) | ||
} | ||
if (result && typeof result === 'string') { | ||
streamResponse(sseStreamer, chatId, result) | ||
} | ||
} else { | ||
const handler = new CustomChainHandler(sseStreamer, chatId, 2) | ||
callbacks.push(handler) | ||
response = await chain.invoke(obj, { callbacks }) | ||
} | ||
} else { | ||
response = await chain.invoke(obj, { callbacks }) | ||
} | ||
|
||
return formatResponse(response?.result) | ||
} catch (error) { | ||
console.error('Error in GraphCypherQAChain:', error) | ||
if (shouldStreamResponse) { | ||
streamResponse(sseStreamer, chatId, error.message) | ||
} | ||
return formatResponse(`Error: ${error.message}`) | ||
} | ||
} | ||
} | ||
|
||
module.exports = { nodeClass: GraphCypherQA_Chain } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we can have all of these prompts as
optional
.in future, we will be adding conditional input
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nice, Cypher Generation Prompt now will be optional