[typescript] Save output.data with text content instead of response data #603

Merged · merged 1 commit · Dec 27, 2023
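This PR changes the model parsers to save a prompt output's text content in `output.data` instead of the full provider response; the raw response moves to `output.metadata.rawResponse`. A minimal before/after sketch of the saved output shape, using the same values the updated hf.test.ts below asserts on (the `oldOutput`/`newOutput` names are illustrative):

```typescript
import { ExecuteResult } from "aiconfig";

// Old format: output.data held the whole response object.
const oldOutput = {
  output_type: "execute_result",
  data: { generated_text: "Test text generation" },
  execution_count: 0,
  metadata: {},
} as ExecuteResult;

// New format: output.data is the text itself, and the full response is
// preserved under metadata.rawResponse for consumers that still need it.
const newOutput = {
  output_type: "execute_result",
  data: "Test text generation",
  execution_count: 0,
  metadata: { rawResponse: { generated_text: "Test text generation" } },
} as ExecuteResult;
```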
extensions/HuggingFace/typescript/hf.ts (31 changes: 19 additions & 12 deletions)

```diff
@@ -14,7 +14,7 @@ import {
   ExecuteResult,
   AIConfigRuntime,
   InferenceOptions,
-  CallbackEvent
+  CallbackEvent,
 } from "aiconfig";
 import _ from "lodash";
 import * as aiconfig from "aiconfig";
@@ -211,7 +211,7 @@ export class HuggingFaceTextGenerationModelParserExtension extends Parameterized
       const response = await this.hfClient.textGenerationStream(
         textGenerationArgs
       );
-      output = await ConstructStreamOutput(
+      output = await constructStreamOutput(
         response,
         options as InferenceOptions
       );
@@ -248,11 +248,19 @@ export class HuggingFaceTextGenerationModelParserExtension extends Parameterized
     }
 
     if (output.output_type === "execute_result") {
-      return (output.data as TextGenerationOutput | TextGenerationStreamOutput)
-        .generated_text as string;
-    } else {
-      return "";
+      if (typeof output.data === "string") {
+        return output.data;
+      }
+
+      // Doing this to be backwards-compatible with old output format
+      // where we used to save the response in output.data
+      if (output.data?.hasOwnProperty("generated_text")) {
+        return (
+          output.data as TextGenerationOutput | TextGenerationStreamOutput
+        ).generated_text as string;
+      }
     }
+    return "";
   }
 }
@@ -262,7 +270,7 @@ export class HuggingFaceTextGenerationModelParserExtension extends Parameterized
  * @param options
  * @returns
  */
-async function ConstructStreamOutput(
+async function constructStreamOutput(
   response: AsyncGenerator<TextGenerationStreamOutput>,
   options: InferenceOptions
 ): Promise<Output> {
@@ -280,6 +288,8 @@ async function ConstructStreamOutput(
 
     output = {
       output_type: "execute_result",
+      // TODO: Investigate if we should use the accumulated message instead
+      // of delta: https://github.com/lastmile-ai/aiconfig/issues/620
       data: delta,
       execution_count: index,
       metadata: metadata,
@@ -289,14 +299,11 @@ async function ConstructStreamOutput(
 }
 
 function constructOutput(response: TextGenerationOutput): Output {
-  const metadata = {};
-  const data = response;
-
   const output = {
     output_type: "execute_result",
-    data: data,
+    data: response.generated_text,
     execution_count: 0,
-    metadata: metadata,
+    metadata: { rawResponse: response },
   } as ExecuteResult;
 
   return output;
```
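The backwards-compatibility rule added above can be read as a small standalone helper. This is a sketch of that logic, assuming `data` is either the new string format or an old HuggingFace response object (`resolveHfOutputText` is a hypothetical name, not part of the PR):

```typescript
import { TextGenerationOutput } from "@huggingface/inference";

// Sketch: resolve the output text from either output format.
function resolveHfOutputText(data: unknown): string {
  if (typeof data === "string") {
    // New format: data is already the generated text.
    return data;
  }
  if (data != null && Object.prototype.hasOwnProperty.call(data, "generated_text")) {
    // Old format: data was the full response, e.g. { generated_text: "..." }.
    return (data as TextGenerationOutput).generated_text;
  }
  return "";
}
```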
typescript/__tests__/parsers/hf/hf.test.ts (8 changes: 6 additions & 2 deletions)

```diff
@@ -257,9 +257,13 @@ describe("HuggingFaceTextGeneration ModelParser", () => {
 
     const expectedOutput = {
       output_type: "execute_result",
-      data: { generated_text: "Test text generation" },
+      data: "Test text generation",
       execution_count: 0,
-      metadata: {},
+      metadata: {
+        rawResponse: {
+          generated_text: "Test text generation",
+        },
+      },
     };
 
     expect(outputWithConfigParam).toEqual([expectedOutput]);
```
typescript/demo/function-call-stream.ts (10 changes: 6 additions & 4 deletions)

```diff
@@ -217,18 +217,20 @@ async function functionCallingWithAIConfig() {
       return;
     }
 
-    const message = output.data as ChatCompletionMessageParam;
+    const rawResponse = output.metadata?.rawResponse;
+    const functionCall = rawResponse?.function_call;
+    console.log("functionCall=", functionCall);
 
     // If there is no function call, we're done and can exit this loop
-    if (!message.function_call) {
+    if (!functionCall) {
       return;
     }
 
     // If there is a function call, we generate a new message with the role 'function'.
-    const result = await callFunction(message.function_call);
+    const result = await callFunction(functionCall);
     const newMessage = {
       role: "function" as const,
-      name: message.function_call.name!,
+      name: functionCall.name!,
       content: JSON.stringify(result),
     };
```
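The demo shows the consumer-side impact: the assistant's text now comes straight from `output.data`, while `function_call` has to be read from `metadata.rawResponse`. A condensed sketch of the new access pattern (`readChatOutput` is a hypothetical helper; `rawResponse` is typed loosely on purpose since the PR stores it untyped):

```typescript
import { ExecuteResult } from "aiconfig";

// Sketch: pull the text and any function call out of an output written by
// the updated OpenAI chat parser.
function readChatOutput(output: ExecuteResult) {
  const text = typeof output.data === "string" ? output.data : "";
  const rawResponse = output.metadata?.rawResponse as
    | { function_call?: { name: string; arguments: string } }
    | undefined;
  return { text, functionCall: rawResponse?.function_call };
}
```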
typescript/lib/parsers/hf.ts (32 changes: 19 additions & 13 deletions)

```diff
@@ -203,7 +203,7 @@ export class HuggingFaceTextGenerationParser extends ParameterizedModelParser<Te
       const response = await this.hfClient.textGenerationStream(
         textGenerationArgs
       );
-      output = await ConstructStreamOutput(
+      output = await constructStreamOutput(
         response,
         options as InferenceOptions
       );
@@ -240,11 +240,19 @@ export class HuggingFaceTextGenerationParser extends ParameterizedModelParser<Te
     }
 
     if (output.output_type === "execute_result") {
-      return (output.data as TextGenerationOutput | TextGenerationStreamOutput)
-        .generated_text as string;
-    } else {
-      return "";
+      if (typeof output.data === "string") {
+        return output.data;
+      }
+
+      // Doing this to be backwards-compatible with old output format
+      // where we used to save the response in output.data
+      if (output.data?.hasOwnProperty("generated_text")) {
+        return (
+          output.data as TextGenerationOutput | TextGenerationStreamOutput
+        ).generated_text as string;
+      }
     }
+    return "";
   }
 }
@@ -254,7 +262,7 @@ export class HuggingFaceTextGenerationParser extends ParameterizedModelParser<Te
  * @param options
  * @returns
  */
-async function ConstructStreamOutput(
+async function constructStreamOutput(
   response: AsyncGenerator<TextGenerationStreamOutput>,
   options: InferenceOptions
 ): Promise<Output> {
@@ -272,25 +280,23 @@ async function ConstructStreamOutput(
 
     output = {
       output_type: "execute_result",
+      // TODO: Investigate if we should use the accumulated message instead
+      // of delta: https://github.com/lastmile-ai/aiconfig/issues/620
       data: delta,
       execution_count: index,
-      metadata: metadata,
+      metadata: { metadata },
     } as ExecuteResult;
   }
   return output;
 }
 
 function constructOutput(response: TextGenerationOutput): Output {
-  const metadata = {};
-  const data = response;
-
   const output = {
     output_type: "execute_result",
-    data: data,
+    data: response.generated_text,
     execution_count: 0,
-    metadata: metadata,
+    metadata: { rawResponse: response },
   } as ExecuteResult;
 
   return output;
 }
```
typescript/lib/parsers/openai.ts (90 changes: 69 additions & 21 deletions)

```diff
@@ -170,7 +170,6 @@ export class OpenAIModelParser extends ParameterizedModelParser<CompletionCreate
 
     // Save response as Output(s) in the Prompt
     const outputs: ExecuteResult[] = [];
-    const responseWithoutChoices = _.omit(response, "choices");
     for (const choice of response.choices) {
       const output: ExecuteResult = {
         output_type: "execute_result",
@@ -179,7 +178,7 @@ export class OpenAIModelParser extends ParameterizedModelParser<CompletionCreate
         metadata: {
           finish_reason: choice.finish_reason,
           logprobs: choice.logprobs,
-          ...responseWithoutChoices,
+          rawResponse: response,
         },
       };
       outputs.push(output);
@@ -253,10 +252,26 @@ export class OpenAIModelParser extends ParameterizedModelParser<CompletionCreate
     }
 
     if (output.output_type === "execute_result") {
-      return output.data as string;
-    } else {
-      return "";
+      // TODO: Add in OutputData another way to support function calls
+      if (typeof output.data === "string") {
+        return output.data;
+      }
+
+      // Doing this to be backwards-compatible with old output format
+      // where we used to save the ChatCompletionMessageParam in output.data
+      if (
+        output.data?.hasOwnProperty("content") &&
+        output.data?.hasOwnProperty("role")
+      ) {
+        const message = output.data as Chat.ChatCompletionMessageParam;
+        if (message.content != null) {
+          return message.content;
+        } else if (message.function_call) {
+          return JSON.stringify(message.function_call);
+        }
+      }
     }
+    return "";
   }
 }
@@ -378,7 +393,8 @@ export class OpenAIChatModelParser extends ParameterizedModelParser<Chat.ChatCom
           ? [
               {
                 output_type: "execute_result",
-                data: { ...assistantResponse },
+                data: assistantResponse.content,
+                metadata: { rawResponse: assistantResponse },
               },
             ]
           : undefined,
@@ -538,11 +554,12 @@ export class OpenAIChatModelParser extends ParameterizedModelParser<Chat.ChatCom
     for (const choice of response.choices) {
       const output: ExecuteResult = {
         output_type: "execute_result",
-        data: { ...choice.message },
+        data: choice.message?.content,
         execution_count: choice.index,
         metadata: {
           finish_reason: choice.finish_reason,
           ...responseWithoutChoices,
+          rawResponse: choice.message,
         },
       };
@@ -587,10 +604,13 @@ export class OpenAIChatModelParser extends ParameterizedModelParser<Chat.ChatCom
 
           const output: ExecuteResult = {
             output_type: "execute_result",
-            data: { ...message },
+            // TODO (rossdanlm): Handle ChatCompletionMessage.function_call
+            // too (next diff)
+            data: message?.content,
             execution_count: choice.index,
             metadata: {
               finish_reason: choice.finish_reason,
+              rawResponse: message,
             },
           };
           outputs.set(choice.index, output);
@@ -625,17 +645,26 @@ export class OpenAIChatModelParser extends ParameterizedModelParser<Chat.ChatCom
     }
 
     if (output.output_type === "execute_result") {
-      const message = output.data as Chat.ChatCompletionMessageParam;
-      if (message.content != null) {
-        return message.content;
-      } else if (message.function_call) {
-        return JSON.stringify(message.function_call);
-      } else {
-        return "";
+      // TODO: Add in OutputData another way to support function calls
+      if (typeof output.data === "string") {
+        return output.data;
+      }
+
+      // Doing this to be backwards-compatible with old output format
+      // where we used to save the ChatCompletionMessageParam in output.data
+      if (
+        output.data?.hasOwnProperty("content") &&
+        output.data?.hasOwnProperty("role")
+      ) {
+        const message = output.data as Chat.ChatCompletionMessageParam;
+        if (message.content != null) {
+          return message.content;
+        } else if (message.function_call) {
+          return JSON.stringify(message.function_call);
+        }
       }
-    } else {
-      return "";
     }
+    return "";
   }
 
   private addPromptAsMessage(
@@ -671,11 +700,30 @@ export class OpenAIChatModelParser extends ParameterizedModelParser<Chat.ChatCom
     const output = aiConfig.getLatestOutput(prompt);
     if (output != null) {
       if (output.output_type === "execute_result") {
-        const outputMessage =
-          output.data as unknown as Chat.ChatCompletionMessageParam;
         // If the prompt has output saved, add it to the messages array
-        if (outputMessage.role === "assistant") {
-          messages.push(outputMessage);
+        if (output.metadata?.rawResponse?.role === "assistant") {
+          if (typeof output.data === "string") {
+            messages.push({
+              content: output.data,
+              // TODO (rossdanlm): Support function_call and don't rely on rawResponse
+              role: output.metadata?.rawResponse?.role,
+              function_call: output.metadata?.rawResponse?.function_call,
+              name: output.metadata?.rawResponse?.name,
+            });
+          }
+        }
+
+        // Doing this to be backwards-compatible with old output format
+        // where we used to save the ChatCompletionMessageParam in output.data
+        else if (
+          output.data?.hasOwnProperty("content") &&
+          output.data?.hasOwnProperty("role")
+        ) {
+          const outputMessage =
+            output.data as unknown as Chat.ChatCompletionMessageParam;
+          if (outputMessage.role === "assistant") {
+            messages.push(outputMessage);
+          }
         }
       }
     }
```
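Both text-extraction hunks above, in OpenAIModelParser and OpenAIChatModelParser, apply the same rule. Distilled into a standalone sketch (the helper name and the structural message type are assumptions for illustration, standing in for OpenAI's `Chat.ChatCompletionMessageParam`):

```typescript
// Structural stand-in for OpenAI's chat message type; only the fields the
// text-extraction logic cares about are modeled here.
type ChatMessageLike = {
  role: string;
  content: string | null;
  function_call?: { name: string; arguments: string };
};

// Sketch: new outputs store a plain string in data; old outputs stored the
// whole chat message object, so fall back to content / function_call.
function getTextFromOutputData(data: unknown): string {
  if (typeof data === "string") {
    return data; // new format
  }
  if (
    data != null &&
    typeof data === "object" &&
    "content" in data &&
    "role" in data
  ) {
    const message = data as ChatMessageLike; // old format
    if (message.content != null) {
      return message.content;
    }
    if (message.function_call) {
      return JSON.stringify(message.function_call);
    }
  }
  return ""; // neither format: nothing to show
}
```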
typescript/lib/parsers/palm.ts (2 changes: 1 addition & 1 deletion)

```diff
@@ -234,7 +234,7 @@ function constructOutputs(
       output_type: "execute_result",
       data: candidate.output,
       execution_count: i,
-      metadata: _.omit(candidate, ["output"]),
+      metadata: { rawResponse: candidate },
     };
 
     outputs.push(output);
```
typescript/package.json (2 changes: 2 additions & 0 deletions)

```diff
@@ -35,13 +35,15 @@
     "@types/jest": "^29.5.10",
     "@types/js-yaml": "^4.0.9",
     "@types/lodash": "^4.14.197",
+    "@types/node": "^20.10.5",
     "@typescript-eslint/eslint-plugin": "^6.7.2",
     "@typescript-eslint/parser": "^6.7.2",
     "dotenv": "^16.3.1",
     "eslint": "^8.50.0",
     "jest": "^29.7.0",
     "ts-jest": "^29.1.1",
     "ts-node": "^10.9.1",
+    "tslib": "^2.6.2",
     "typedoc": "^0.23.27",
     "typescript": "^4.9.5"
   },
```