[typescript] Save output.data with text content instead of response data
This comes after Sarmad's schema updates in #589. To keep diffs small and easier to review, this simply converts from model-specific outputs --> pure text. I have a diff in #610 which converts from pure text --> `OutputData` format.
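
For a concrete picture, this is roughly what the shape change looks like for a HuggingFace text-generation output (a sketch mirroring the updated test expectation further down; field values are illustrative):
```
// Old format: output.data held the model-specific response object
const before = {
  output_type: "execute_result",
  data: { generated_text: "Test text generation" },
  execution_count: 0,
  metadata: {},
};

// New format: output.data is plain text, and the raw model response
// is preserved under metadata.rawResponse
const after = {
  output_type: "execute_result",
  data: "Test text generation",
  execution_count: 0,
  metadata: { rawResponse: { generated_text: "Test text generation" } },
};
```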


We only needed to update `hf.ts` and `openai.ts`, because `palm.ts` already returns output as a `string | null` type.

Ran the yarn automated tests, but there aren't any specifically for openai. I also ran the typescript demos to make sure that they still work. Run these commands from the `aiconfig` top-level dir:
```
npx ts-node typescript/demo/function-call-stream.ts
npx ts-node typescript/demo/demo.ts
npx ts-node typescript/demo/test-hf.ts
```


For the extensions, the only TypeScript change is in `hf.ts` (trivial: just changed `response` to `response.generated_text`), while `llama.ts` already outputs text, so no changes were needed there.
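
Concretely, `constructOutput` in the HF extension ends up looking roughly like this (a sketch; `TextGenerationOutput`, `Output`, and `ExecuteResult` are the same types already imported in `hf.ts`):
```
// Sketch of constructOutput in extensions/HuggingFace/typescript/hf.ts after the change
function constructOutput(response: TextGenerationOutput): Output {
  return {
    output_type: "execute_result",
    data: response.generated_text, // previously: data: response
    execution_count: 0,
    metadata: { rawResponse: response }, // keep the full response around for reference
  } as ExecuteResult;
}
```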


## TODO
I still need to add function call support directly to `OutputData` format. See
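
In the meantime, function calls can still be read from the raw response that now lives in output metadata; the updated `function-call-stream.ts` demo does roughly the following (an excerpt-style sketch from inside the demo's async loop, where `callFunction` is the demo's existing helper):
```
// The function call is read from metadata.rawResponse instead of output.data
const rawResponse = output.metadata?.rawResponse;
const function_call = rawResponse?.function_call;

// If there is no function call, the output is plain text and we're done
if (!function_call) {
  return;
}

// Otherwise, dispatch it and build the follow-up 'function' message
const result = await callFunction(function_call);
const newMessage = {
  role: "function" as const,
  name: function_call.name!,
  content: JSON.stringify(result),
};
```
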
Rossdan Craig [email protected] committed Dec 26, 2023
1 parent 5e336f0 commit 15c2e66
Showing 10 changed files with 6,956 additions and 122 deletions.
39 changes: 39 additions & 0 deletions config.json
@@ -0,0 +1,39 @@
{
"name": "exploring nyc through chatgpt config",
"description": "",
"schema_version": "latest",
"metadata": {
"parameters": {},
"models": {
"mistralai/Mistral-7B-v0.1": {
"model": "mistralai/Mistral-7B-v0.1",
"top_p": 0.9,
"temperature": 0.9,
"stream": true
}
}
},
"prompts": [
{
"name": "prompt1",
"input": "Hi! Tell me 10 cool things to do in NYC.",
"metadata": {
"model": {
"name": "mistralai/Mistral-7B-v0.1"
},
"remember_chat_context": true
}
},
{
"name": "prompt2",
"input": "Hello, world!",
"metadata": {
"model": {
"name": "HuggingFaceTextGenerationParser"
},
"parameters": {}
}
}
],
"$schema": "https://json.schemastore.org/aiconfig-1.0"
}
31 changes: 19 additions & 12 deletions extensions/HuggingFace/typescript/hf.ts
@@ -14,7 +14,7 @@ import {
ExecuteResult,
AIConfigRuntime,
InferenceOptions,
CallbackEvent
CallbackEvent,
} from "aiconfig";
import _ from "lodash";
import * as aiconfig from "aiconfig";
@@ -211,7 +211,7 @@ export class HuggingFaceTextGenerationModelParserExtension extends Parameterized
const response = await this.hfClient.textGenerationStream(
textGenerationArgs
);
output = await ConstructStreamOutput(
output = await constructStreamOutput(
response,
options as InferenceOptions
);
@@ -248,11 +248,19 @@ export class HuggingFaceTextGenerationModelParserExtension extends Parameterized
}

if (output.output_type === "execute_result") {
return (output.data as TextGenerationOutput | TextGenerationStreamOutput)
.generated_text as string;
} else {
return "";
if (typeof output.data === "string") {
return output.data;
}

// Doing this to be backwards-compatible with old output format
// where we used to save the response in output.data
if (output.data?.hasOwnProperty("generated_text")) {
return (
output.data as TextGenerationOutput | TextGenerationStreamOutput
).generated_text as string;
}
}
return "";
}
}

@@ -262,7 +270,7 @@ export class HuggingFaceTextGenerationModelParserExtension extends Parameterized
* @param options
* @returns
*/
async function ConstructStreamOutput(
async function constructStreamOutput(
response: AsyncGenerator<TextGenerationStreamOutput>,
options: InferenceOptions
): Promise<Output> {
@@ -280,6 +288,8 @@ async function ConstructStreamOutput(

output = {
output_type: "execute_result",
// TODO: Investigate if we should use the accumulated message instead
// of delta: https://github.com/lastmile-ai/aiconfig/issues/620
data: delta,
execution_count: index,
metadata: metadata,
@@ -289,14 +299,11 @@ async function ConstructStreamOutput(
}

function constructOutput(response: TextGenerationOutput): Output {
const metadata = {};
const data = response;

const output = {
output_type: "execute_result",
data: data,
data: response.generated_text,
execution_count: 0,
metadata: metadata,
metadata: { rawResponse: response },
} as ExecuteResult;

return output;
8 changes: 6 additions & 2 deletions typescript/__tests__/parsers/hf/hf.test.ts
@@ -257,9 +257,13 @@ describe("HuggingFaceTextGeneration ModelParser", () => {

const expectedOutput = {
output_type: "execute_result",
data: { generated_text: "Test text generation" },
data: "Test text generation",
execution_count: 0,
metadata: {},
metadata: {
rawResponse: {
generated_text: "Test text generation",
},
},
};

expect(outputWithConfigParam).toEqual([expectedOutput]);
10 changes: 6 additions & 4 deletions typescript/demo/function-call-stream.ts
@@ -217,18 +217,20 @@ async function functionCallingWithAIConfig() {
return;
}

const message = output.data as ChatCompletionMessageParam;
const rawResponse = output.metadata?.rawResponse;
const function_call = rawResponse?.function_call;
console.log("function_call=", function_call);

// If there is no function call, we're done and can exit this loop
if (!message.function_call) {
if (!function_call) {
return;
}

// If there is a function call, we generate a new message with the role 'function'.
const result = await callFunction(message.function_call);
const result = await callFunction(function_call);
const newMessage = {
role: "function" as const,
name: message.function_call.name!,
name: function_call.name!,
content: JSON.stringify(result),
};

32 changes: 19 additions & 13 deletions typescript/lib/parsers/hf.ts
@@ -203,7 +203,7 @@ export class HuggingFaceTextGenerationParser extends ParameterizedModelParser<Te
const response = await this.hfClient.textGenerationStream(
textGenerationArgs
);
output = await ConstructStreamOutput(
output = await constructStreamOutput(
response,
options as InferenceOptions
);
@@ -240,11 +240,19 @@ export class HuggingFaceTextGenerationParser extends ParameterizedModelParser<Te
}

if (output.output_type === "execute_result") {
return (output.data as TextGenerationOutput | TextGenerationStreamOutput)
.generated_text as string;
} else {
return "";
if (typeof output.data === "string") {
return output.data;
}

// Doing this to be backwards-compatible with old output format
// where we used to save the response in output.data
if (output.data?.hasOwnProperty("generated_text")) {
return (
output.data as TextGenerationOutput | TextGenerationStreamOutput
).generated_text as string;
}
}
return "";
}
}

@@ -254,7 +262,7 @@ export class HuggingFaceTextGenerationParser extends ParameterizedModelParser<Te
* @param options
* @returns
*/
async function ConstructStreamOutput(
async function constructStreamOutput(
response: AsyncGenerator<TextGenerationStreamOutput>,
options: InferenceOptions
): Promise<Output> {
@@ -272,25 +280,23 @@ async function ConstructStreamOutput(

output = {
output_type: "execute_result",
// TODO: Investigate if we should use the accumulated message instead
// of delta: https://github.com/lastmile-ai/aiconfig/issues/620
data: delta,
execution_count: index,
metadata: metadata,
metadata: { metadata },
} as ExecuteResult;
}
return output;
}

function constructOutput(response: TextGenerationOutput): Output {
const metadata = {};
const data = response;

const output = {
output_type: "execute_result",
data: data,
data: response.generated_text,
execution_count: 0,
metadata: metadata,
metadata: { rawResponse: response },
} as ExecuteResult;

return output;
}

55 changes: 37 additions & 18 deletions typescript/lib/parsers/openai.ts
@@ -170,7 +170,6 @@ export class OpenAIModelParser extends ParameterizedModelParser<CompletionCreate

// Save response as Output(s) in the Prompt
const outputs: ExecuteResult[] = [];
const responseWithoutChoices = _.omit(response, "choices");
for (const choice of response.choices) {
const output: ExecuteResult = {
output_type: "execute_result",
@@ -179,7 +178,7 @@ metadata: {
metadata: {
finish_reason: choice.finish_reason,
logprobs: choice.logprobs,
...responseWithoutChoices,
rawResponse: response,
},
};
outputs.push(output);
@@ -378,7 +377,8 @@ export class OpenAIChatModelParser extends ParameterizedModelParser<Chat.ChatCom
? [
{
output_type: "execute_result",
data: { ...assistantResponse },
data: assistantResponse.content,
metadata: { rawResponse: assistantResponse },
},
]
: undefined,
@@ -538,11 +538,12 @@ export class OpenAIChatModelParser extends ParameterizedModelParser<Chat.ChatCom
for (const choice of response.choices) {
const output: ExecuteResult = {
output_type: "execute_result",
data: { ...choice.message },
data: choice.message?.content,
execution_count: choice.index,
metadata: {
finish_reason: choice.finish_reason,
...responseWithoutChoices,
rawResponse: choice.message,
},
};

@@ -587,10 +588,13 @@ export class OpenAIChatModelParser extends ParameterizedModelParser<Chat.ChatCom

const output: ExecuteResult = {
output_type: "execute_result",
data: { ...message },
// TODO (rossdanlm): Handle ChatCompletionMessage.function_call
// too (next diff)
data: message?.content,
execution_count: choice.index,
metadata: {
finish_reason: choice.finish_reason,
rawResponse: message,
},
};
outputs.set(choice.index, output);
@@ -625,17 +629,26 @@ export class OpenAIChatModelParser extends ParameterizedModelParser<Chat.ChatCom
}

if (output.output_type === "execute_result") {
const message = output.data as Chat.ChatCompletionMessageParam;
if (message.content != null) {
return message.content;
} else if (message.function_call) {
return JSON.stringify(message.function_call);
} else {
return "";
// TODO: Add in OutputData another way to support function calls
if (typeof output.data === "string") {
return output.data;
}

// Doing this to be backwards-compatible with old output format
// where we used to save the ChatCompletionMessageParam in output.data
if (
output.data?.hasOwnProperty("content") &&
output.data?.hasOwnProperty("role")
) {
const message = output.data as Chat.ChatCompletionMessageParam;
if (message.content != null) {
return message.content;
} else if (message.function_call) {
return JSON.stringify(message.function_call);
}
}
} else {
return "";
}
return "";
}

private addPromptAsMessage(
Expand Down Expand Up @@ -671,11 +684,17 @@ export class OpenAIChatModelParser extends ParameterizedModelParser<Chat.ChatCom
const output = aiConfig.getLatestOutput(prompt);
if (output != null) {
if (output.output_type === "execute_result") {
const outputMessage =
output.data as unknown as Chat.ChatCompletionMessageParam;
// If the prompt has output saved, add it to the messages array
if (outputMessage.role === "assistant") {
messages.push(outputMessage);
if (output.metadata?.role === "assistant") {
if (typeof output.data === "string") {
messages.push({
content: output.data,
role: output.metadata?.role,
function_call: output.metadata?.function_call,
name: output.metadata?.name,
});
}
// TODO (rossdanlm): Support function_call
}
}
}
2 changes: 1 addition & 1 deletion typescript/lib/parsers/palm.ts
@@ -234,7 +234,7 @@ function constructOutputs(
output_type: "execute_result",
data: candidate.output,
execution_count: i,
metadata: _.omit(candidate, ["output"]),
metadata: { rawResponse: candidate },
};

outputs.push(output);