Skip to content

Commit 6e42e10

Browse files
authored
Merge pull request #576 from FlowiseAI/feature/Replicate
Feature/Add ReplicateLLM
2 parents 8d2b9cc + 5d561bf commit 6e42e10

File tree

4 files changed

+400
-1
lines changed

4 files changed

+400
-1
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
import { INode, INodeData, INodeParams } from '../../../src/Interface'
2+
import { getBaseClasses } from '../../../src/utils'
3+
import { Replicate, ReplicateInput } from 'langchain/llms/replicate'
4+
5+
class Replicate_LLMs implements INode {
6+
label: string
7+
name: string
8+
type: string
9+
icon: string
10+
category: string
11+
description: string
12+
baseClasses: string[]
13+
inputs: INodeParams[]
14+
15+
constructor() {
16+
this.label = 'Replicate'
17+
this.name = 'replicate'
18+
this.type = 'Replicate'
19+
this.icon = 'replicate.svg'
20+
this.category = 'LLMs'
21+
this.description = 'Use Replicate to run open source models on cloud'
22+
this.baseClasses = [this.type, 'BaseChatModel', ...getBaseClasses(Replicate)]
23+
this.inputs = [
24+
{
25+
label: 'Replicate Api Key',
26+
name: 'replicateApiKey',
27+
type: 'password'
28+
},
29+
{
30+
label: 'Model',
31+
name: 'model',
32+
type: 'string',
33+
placeholder: 'a16z-infra/llama13b-v2-chat:df7690f1994d94e96ad9d568eac121aecf50684a0b0963b25a41cc40061269e5',
34+
optional: true
35+
},
36+
{
37+
label: 'Temperature',
38+
name: 'temperature',
39+
type: 'number',
40+
description:
41+
'Adjusts randomness of outputs, greater than 1 is random and 0 is deterministic, 0.75 is a good starting value.',
42+
default: 0.7,
43+
optional: true
44+
},
45+
{
46+
label: 'Max Tokens',
47+
name: 'maxTokens',
48+
type: 'number',
49+
description: 'Maximum number of tokens to generate. A word is generally 2-3 tokens',
50+
optional: true,
51+
additionalParams: true
52+
},
53+
{
54+
label: 'Top Probability',
55+
name: 'topP',
56+
type: 'number',
57+
description:
58+
'When decoding text, samples from the top p percentage of most likely tokens; lower to ignore less likely tokens',
59+
optional: true,
60+
additionalParams: true
61+
},
62+
{
63+
label: 'Repetition Penalty',
64+
name: 'repetitionPenalty',
65+
type: 'number',
66+
description:
67+
'Penalty for repeated words in generated text; 1 is no penalty, values greater than 1 discourage repetition, less than 1 encourage it. (minimum: 0.01; maximum: 5)',
68+
optional: true,
69+
additionalParams: true
70+
},
71+
{
72+
label: 'Additional Inputs',
73+
name: 'additionalInputs',
74+
type: 'json',
75+
description:
76+
'Each model has different parameters, refer to the specific model accepted inputs. For example: <a target="_blank" href="https://replicate.com/a16z-infra/llama13b-v2-chat/api#inputs">llama13b-v2</a>',
77+
additionalParams: true,
78+
optional: true
79+
}
80+
]
81+
}
82+
83+
async init(nodeData: INodeData): Promise<any> {
84+
const modelName = nodeData.inputs?.model as string
85+
const apiKey = nodeData.inputs?.replicateApiKey as string
86+
const temperature = nodeData.inputs?.temperature as string
87+
const maxTokens = nodeData.inputs?.maxTokens as string
88+
const topP = nodeData.inputs?.topP as string
89+
const repetitionPenalty = nodeData.inputs?.repetitionPenalty as string
90+
const additionalInputs = nodeData.inputs?.additionalInputs as string
91+
92+
const version = modelName.split(':').pop()
93+
const name = modelName.split(':')[0].split('/').pop()
94+
const org = modelName.split(':')[0].split('/')[0]
95+
96+
const obj: ReplicateInput = {
97+
model: `${org}/${name}:${version}`,
98+
apiKey
99+
}
100+
101+
let inputs: any = {}
102+
if (maxTokens) inputs.max_length = parseInt(maxTokens, 10)
103+
if (temperature) inputs.temperature = parseFloat(temperature)
104+
if (topP) inputs.top_p = parseFloat(topP)
105+
if (repetitionPenalty) inputs.repetition_penalty = parseFloat(repetitionPenalty)
106+
if (additionalInputs) {
107+
const parsedInputs =
108+
typeof additionalInputs === 'object' ? additionalInputs : additionalInputs ? JSON.parse(additionalInputs) : {}
109+
inputs = { ...inputs, ...parsedInputs }
110+
}
111+
if (Object.keys(inputs).length) obj.input = inputs
112+
113+
const model = new Replicate(obj)
114+
return model
115+
}
116+
}
117+
118+
module.exports = { nodeClass: Replicate_LLMs }
Loading

packages/components/package.json

+2-1
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
"form-data": "^4.0.0",
3838
"graphql": "^16.6.0",
3939
"html-to-text": "^9.0.5",
40-
"langchain": "^0.0.104",
40+
"langchain": "^0.0.112",
4141
"linkifyjs": "^4.1.1",
4242
"mammoth": "^1.5.1",
4343
"moment": "^2.29.3",
@@ -48,6 +48,7 @@
4848
"playwright": "^1.35.0",
4949
"puppeteer": "^20.7.1",
5050
"redis": "^4.6.7",
51+
"replicate": "^0.12.3",
5152
"srt-parser-2": "^1.2.3",
5253
"vm2": "^3.9.19",
5354
"weaviate-ts-client": "^1.1.0",

0 commit comments

Comments
 (0)