Brace/runnable chain (#3200)
* Runnable support for LLMChains

* fix typing

* nit

* nit

* fix build & typing

* chore: lint files

* cr

* cr

* cr

* no instance of

* Adds type guard

* fix signal bug

* cr

* cr

---------

Co-authored-by: jacoblee93 <[email protected]>
bracesproul and jacoblee93 authored Nov 11, 2023
1 parent 719a131 commit 0cb640f
Showing 5 changed files with 119 additions and 32 deletions.
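In short, `LLMChain` now accepts any model-shaped `Runnable`, not only a `BaseLanguageModel`, so a model with bound call options works directly. A minimal usage sketch, based on the integration test added below (import paths assume the `langchain` package entrypoints at the time of this commit):

```typescript
import { ChatOpenAI } from "langchain/chat_models/openai";
import { PromptTemplate } from "langchain/prompts";
import { LLMChain } from "langchain/chains";

// bind() returns a Runnable, not a BaseLanguageModel; this commit
// lets LLMChain take it as its `llm` all the same.
const model = new ChatOpenAI({ modelName: "gpt-3.5-turbo-1106" }).bind({
  response_format: { type: "json_object" },
});

const prompt = PromptTemplate.fromTemplate(
  "Respond with a JSON object containing a 'spelling' key.\nQuestion: {input}"
);
const chain = new LLMChain({ llm: model, prompt });

const result = await chain.invoke({ input: "How do you spell today?" });
console.log(JSON.parse(result.text)); // e.g. { spelling: "t-o-d-a-y" }
```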
4 changes: 3 additions & 1 deletion langchain/src/agents/openai/index.ts
@@ -204,7 +204,9 @@ export class OpenAIAgent extends Agent {
     const valuesForLLM: (typeof llm)["CallOptions"] = {
       functions: this.tools.map(formatToOpenAIFunction),
     };
-    for (const key of this.llmChain.llm.callKeys) {
+    const callKeys =
+      "callKeys" in this.llmChain.llm ? this.llmChain.llm.callKeys : [];
+    for (const key of callKeys) {
       if (key in inputs) {
         valuesForLLM[key as keyof (typeof llm)["CallOptions"]] = inputs[key];
         delete valuesForPrompt[key];
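The fix above is the recurring idiom of this commit: `llm` may now be a plain `Runnable` without a `callKeys` property, so the code narrows with the `in` operator and falls back to an empty list. A standalone sketch of the narrowing (interfaces are illustrative, not the library's):

```typescript
interface ModelWithCallKeys {
  callKeys: string[];
}
interface PlainRunnable {
  invoke(input: string): Promise<string>;
}

// `in` narrows the union: only the branch that has callKeys reads it.
function getCallKeys(llm: ModelWithCallKeys | PlainRunnable): string[] {
  return "callKeys" in llm ? llm.callKeys : [];
}
```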
5 changes: 3 additions & 2 deletions langchain/src/chains/combine_docs_chain.ts
@@ -206,8 +206,9 @@ export class MapReduceDocumentsChain
         ...rest,
       })
     );
-    const length =
-      await this.combineDocumentChain.llmChain.llm.getNumTokens(formatted);
+    const length = await this.combineDocumentChain.llmChain._getNumTokens(
+      formatted
+    );
 
     const withinTokenLimit = length < this.maxTokens;
     // If we can skip the map step, and we're within the token limit, we don't
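The old line called `getNumTokens` on `llmChain.llm` directly, which no longer typechecks once `llm` may be a wrapper runnable without that method; counting now goes through the chain's new `_getNumTokens` (added in llm_chain.ts below), which first resolves the underlying model. A minimal sketch of the failure mode and the unwrap, with stand-in types:

```typescript
interface TokenCountingModel {
  getNumTokens(text: string): Promise<number>;
}
// A bound runnable wraps the model and does not re-expose getNumTokens.
interface BoundRunnable {
  bound: TokenCountingModel;
}

async function numTokens(
  llm: TokenCountingModel | BoundRunnable,
  text: string
): Promise<number> {
  // llm.getNumTokens(text) alone would fail for BoundRunnable;
  // unwrap first, as LLMChain._getNumTokens does via _getLanguageModel.
  const model = "getNumTokens" in llm ? llm : llm.bound;
  return model.getNumTokens(text);
}
```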
121 changes: 94 additions & 27 deletions langchain/src/chains/llm_chain.ts
@@ -1,7 +1,15 @@
 import { BaseChain, ChainInputs } from "./base.js";
 import { BasePromptTemplate } from "../prompts/base.js";
-import { BaseLanguageModel } from "../base_language/index.js";
-import { ChainValues, Generation, BasePromptValue } from "../schema/index.js";
+import {
+  BaseLanguageModel,
+  BaseLanguageModelInput,
+} from "../base_language/index.js";
+import {
+  ChainValues,
+  Generation,
+  BasePromptValue,
+  BaseMessage,
+} from "../schema/index.js";
 import {
   BaseLLMOutputParser,
   BaseOutputParser,
@@ -14,26 +22,56 @@ import {
   Callbacks,
 } from "../callbacks/manager.js";
 import { NoOpOutputParser } from "../output_parsers/noop.js";
+import { Runnable } from "../schema/runnable/base.js";
+
+type LLMType =
+  | BaseLanguageModel
+  | Runnable<BaseLanguageModelInput, string>
+  | Runnable<BaseLanguageModelInput, BaseMessage>;
+
+// eslint-disable-next-line @typescript-eslint/no-explicit-any
+type CallOptionsIfAvailable<T> = T extends { CallOptions: infer CO } ? CO : any;
 
 /**
  * Interface for the input parameters of the LLMChain class.
  */
 export interface LLMChainInput<
   T extends string | object = string,
-  L extends BaseLanguageModel = BaseLanguageModel
+  Model extends LLMType = LLMType
 > extends ChainInputs {
   /** Prompt object to use */
   prompt: BasePromptTemplate;
   /** LLM Wrapper to use */
-  llm: L;
+  llm: Model;
   /** Kwargs to pass to LLM */
-  llmKwargs?: this["llm"]["CallOptions"];
+  llmKwargs?: CallOptionsIfAvailable<Model>;
   /** OutputParser to use */
   outputParser?: BaseLLMOutputParser<T>;
   /** Key to use for output, defaults to `text` */
   outputKey?: string;
 }
 
+function isBaseLanguageModel(llmLike: unknown): llmLike is BaseLanguageModel {
+  return typeof (llmLike as BaseLanguageModel)._llmType === "function";
+}
+
+function _getLanguageModel(llmLike: Runnable): BaseLanguageModel {
+  if (isBaseLanguageModel(llmLike)) {
+    return llmLike;
+  } else if ("bound" in llmLike && Runnable.isRunnable(llmLike.bound)) {
+    return _getLanguageModel(llmLike.bound);
+  } else if (
+    "runnable" in llmLike &&
+    "fallbacks" in llmLike &&
+    Runnable.isRunnable(llmLike.runnable)
+  ) {
+    return _getLanguageModel(llmLike.runnable);
+  } else if ("default" in llmLike && Runnable.isRunnable(llmLike.default)) {
+    return _getLanguageModel(llmLike.default);
+  } else {
+    throw new Error("Unable to extract BaseLanguageModel from llmLike object.");
+  }
+}
+
 /**
  * Chain to run queries against LLMs.
  *
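Two building blocks above do the heavy lifting: `CallOptionsIfAvailable<T>` pulls out a model's `CallOptions` type when one exists and falls back to `any`, while `_getLanguageModel` recursively unwraps wrapper runnables until it finds a real model. A standalone sketch of the conditional type (the `ExampleModel` shape is hypothetical):

```typescript
type CallOptionsIfAvailable<T> = T extends { CallOptions: infer CO } ? CO : any;

// Hypothetical model type that declares its call options.
type ExampleModel = { CallOptions: { stop?: string[]; timeout?: number } };

// Resolves to { stop?: string[]; timeout?: number }:
type WithOptions = CallOptionsIfAvailable<ExampleModel>;
// Resolves to any, since plain runnables declare no CallOptions:
type WithoutOptions = CallOptionsIfAvailable<{ invoke(x: string): void }>;
```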
@@ -49,7 +87,7 @@ export interface LLMChainInput<
  */
 export class LLMChain<
     T extends string | object = string,
-    L extends BaseLanguageModel = BaseLanguageModel
+    Model extends LLMType = LLMType
   >
   extends BaseChain
   implements LLMChainInput<T>
@@ -62,9 +100,9 @@ export class LLMChain<

   prompt: BasePromptTemplate;
 
-  llm: L;
+  llm: Model;
 
-  llmKwargs?: this["llm"]["CallOptions"];
+  llmKwargs?: CallOptionsIfAvailable<Model>;
 
   outputKey = "text";
 
@@ -78,7 +116,7 @@ export class LLMChain<
     return [this.outputKey];
   }
 
-  constructor(fields: LLMChainInput<T, L>) {
+  constructor(fields: LLMChainInput<T, Model>) {
     super(fields);
     this.prompt = fields.prompt;
     this.llm = fields.llm;
@@ -94,10 +132,16 @@ export class LLMChain<
     }
   }
 
+  private getCallKeys(): string[] {
+    const callKeys = "callKeys" in this.llm ? this.llm.callKeys : [];
+    return callKeys;
+  }
+
   /** @ignore */
   _selectMemoryInputs(values: ChainValues): ChainValues {
     const valuesForMemory = super._selectMemoryInputs(values);
-    for (const key of this.llm.callKeys) {
+    const callKeys = this.getCallKeys();
+    for (const key of callKeys) {
       if (key in values) {
         delete valuesForMemory[key];
       }
@@ -130,39 +174,56 @@ export class LLMChain<
    * Wraps _call and handles memory.
    */
   call(
-    values: ChainValues & this["llm"]["CallOptions"],
+    values: ChainValues & CallOptionsIfAvailable<Model>,
     config?: Callbacks | BaseCallbackConfig
   ): Promise<ChainValues> {
     return super.call(values, config);
   }
 
   /** @ignore */
   async _call(
-    values: ChainValues & this["llm"]["CallOptions"],
+    values: ChainValues & CallOptionsIfAvailable<Model>,
     runManager?: CallbackManagerForChainRun
   ): Promise<ChainValues> {
     const valuesForPrompt = { ...values };
-    const valuesForLLM: this["llm"]["CallOptions"] = {
+    const valuesForLLM = {
       ...this.llmKwargs,
-    };
-    for (const key of this.llm.callKeys) {
+    } as CallOptionsIfAvailable<Model>;
+    const callKeys = this.getCallKeys();
+    for (const key of callKeys) {
       if (key in values) {
-        valuesForLLM[key as keyof this["llm"]["CallOptions"]] = values[key];
-        delete valuesForPrompt[key];
+        if (valuesForLLM) {
+          valuesForLLM[key as keyof CallOptionsIfAvailable<Model>] =
+            values[key];
+          delete valuesForPrompt[key];
+        }
       }
     }
     const promptValue = await this.prompt.formatPromptValue(valuesForPrompt);
-    const { generations } = await this.llm.generatePrompt(
-      [promptValue],
-      valuesForLLM,
+    if ("generatePrompt" in this.llm) {
+      const { generations } = await this.llm.generatePrompt(
+        [promptValue],
+        valuesForLLM,
+        runManager?.getChild()
+      );
+      return {
+        [this.outputKey]: await this._getFinalOutput(
+          generations[0],
+          promptValue,
+          runManager
+        ),
+      };
+    }
+
+    const modelWithParser = this.outputParser
+      ? this.llm.pipe(this.outputParser)
+      : this.llm;
+    const response = await modelWithParser.invoke(
+      promptValue,
       runManager?.getChild()
     );
     return {
-      [this.outputKey]: await this._getFinalOutput(
-        generations[0],
-        promptValue,
-        runManager
-      ),
+      [this.outputKey]: response,
     };
   }
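The heart of the change is the capability check in `_call` above: if the model exposes `generatePrompt`, the legacy `BaseLanguageModel` path runs exactly as before; otherwise the prompt value is handed to the runnable's `invoke` (piped through the output parser when one is set). A condensed, self-contained sketch of that dispatch, with simplified stand-in interfaces:

```typescript
interface LegacyModel {
  generatePrompt(
    prompts: string[]
  ): Promise<{ generations: { text: string }[][] }>;
}
interface RunnableModel {
  invoke(input: string): Promise<string>;
}

// Dispatch on capability, not on class identity.
async function runLLM(
  llm: LegacyModel | RunnableModel,
  prompt: string
): Promise<string> {
  if ("generatePrompt" in llm) {
    // Legacy path: batch-style generate API, first generation wins.
    const { generations } = await llm.generatePrompt([prompt]);
    return generations[0][0].text;
  }
  // Runnable path: a single invoke call.
  return llm.invoke(prompt);
}
```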

@@ -179,7 +240,7 @@ export class LLMChain<
    * ```
    */
   async predict(
-    values: ChainValues & this["llm"]["CallOptions"],
+    values: ChainValues & CallOptionsIfAvailable<Model>,
     callbackManager?: CallbackManager
   ): Promise<T> {
     const output = await this.call(values, callbackManager);
@@ -207,10 +268,16 @@ export class LLMChain<

   /** @deprecated */
   serialize(): SerializedLLMChain {
+    const serialize =
+      "serialize" in this.llm ? this.llm.serialize() : undefined;
     return {
       _type: `${this._chainType()}_chain`,
-      llm: this.llm.serialize(),
+      llm: serialize,
       prompt: this.prompt.serialize(),
     };
   }
+
+  _getNumTokens(text: string): Promise<number> {
+    return _getLanguageModel(this.llm).getNumTokens(text);
+  }
 }
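Two smaller accommodations close out the file: `serialize()` emits `undefined` for the `llm` field when the model is a wrapper runnable with no legacy `serialize` method, and the new `_getNumTokens` routes token counting through `_getLanguageModel` so it works on wrapped models. A sketch of the serialize guard (shapes simplified):

```typescript
interface SerializableModel {
  serialize(): Record<string, unknown>;
}
interface PlainRunnable {
  invoke(input: string): Promise<string>;
}

// Wrapper runnables are not serializable via the legacy path;
// degrade to undefined instead of throwing.
function serializeLLM(
  llm: SerializableModel | PlainRunnable
): Record<string, unknown> | undefined {
  return "serialize" in llm ? llm.serialize() : undefined;
}
```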
17 changes: 17 additions & 0 deletions langchain/src/chains/tests/llm_chain.int.test.ts
@@ -147,3 +147,20 @@ test("Test deserialize", async () => {

   // chain === chain2?
 });
+
+test("Test passing a runnable to an LLMChain", async () => {
+  const model = new ChatOpenAI({ modelName: "gpt-3.5-turbo-1106" });
+  const runnableModel = model.bind({
+    response_format: {
+      type: "json_object",
+    },
+  });
+  const prompt = PromptTemplate.fromTemplate(
+    "You are a bee --I mean a spelling bee. Respond with a JSON key of 'spelling':\nQuestion:{input}"
+  );
+  const chain = new LLMChain({ llm: runnableModel, prompt });
+  const response = await chain.invoke({ input: "How do you spell today?" });
+  expect(JSON.parse(response.text)).toMatchObject({
+    spelling: expect.any(String),
+  });
+});
4 changes: 2 additions & 2 deletions langchain/src/schema/runnable/base.ts
@@ -1395,9 +1395,9 @@ export class RunnableWithFallbacks<RunInput, RunOutput> extends Runnable<

   lc_serializable = true;
 
-  protected runnable: Runnable<RunInput, RunOutput>;
+  runnable: Runnable<RunInput, RunOutput>;
 
-  protected fallbacks: Runnable<RunInput, RunOutput>[];
+  fallbacks: Runnable<RunInput, RunOutput>[];
 
   constructor(fields: {
     runnable: Runnable<RunInput, RunOutput>;
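`runnable` and `fallbacks` lose their `protected` modifier so that `_getLanguageModel` in llm_chain.ts, a free function outside the class, can reach into a `RunnableWithFallbacks` and unwrap its primary runnable. A toy illustration of why `protected` blocked that (class names are stand-ins):

```typescript
class GuardedWrapper {
  protected runnable = "inner model";
}
class OpenWrapper {
  runnable = "inner model";
}

// new GuardedWrapper().runnable; // TS2445: 'runnable' is protected
const unwrapped = new OpenWrapper().runnable; // OK from external helpers
console.log(unwrapped);
```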
