Skip to content

Commit

Permalink
feat: getFunctionCalls() -b closes (#50)
Browse files Browse the repository at this point in the history
  • Loading branch information
gr2m authored Sep 3, 2024
1 parent bb4e12c commit be9c476
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 5 deletions.
17 changes: 17 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,23 @@ await prompt({
});
```
### `getFunctionCalls()`
Convenience metthod if a result from a `prompt()` call includes function calls.
```js
import { prompt, getFunctionCalls } from "@copilot-extensions/preview-sdk";

const result = await prompt(options);
const [functionCall] = getFunctionCalls(result);

if (functionCall) {
console.log("Received a function call", functionCall);
} else {
console.log("No function call received");
}
```
## Dreamcode
While implementing the lower-level functionality, we also dream big: what would our dream SDK for Coplitot extensions look like? Please have a look and share your thoughts and ideas:
Expand Down
13 changes: 12 additions & 1 deletion index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,16 @@ interface PromptInterface {
(options: WithRequired<PromptOptions, "messages">): Promise<PromptResult>;
}

interface GetFunctionCallsInterface {
(payload: PromptResult): {
id: string;
function: {
name: string,
arguments: string,
}
}[]
}

// exported methods

export declare const verifyRequest: VerifyRequestInterface;
Expand All @@ -325,4 +335,5 @@ export declare const verifyAndParseRequest: VerifyAndParseRequestInterface;
export declare const getUserMessage: GetUserMessageInterface;
export declare const getUserConfirmation: GetUserConfirmationInterface;

export declare const prompt: PromptInterface;
export declare const prompt: PromptInterface;
export declare const getFunctionCalls: GetFunctionCallsInterface;
13 changes: 13 additions & 0 deletions index.test-d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import {
type InteropMessage,
CopilotRequestPayload,
prompt,
PromptResult,
getFunctionCalls,
} from "./index.js";

const token = "";
Expand Down Expand Up @@ -335,4 +337,15 @@ export async function promptWithoutMessageButMessages() {
{ role: "user", content: "What about Spain?" },
],
});
}

export async function getFunctionCallsTest(promptResponsePayload: PromptResult) {
const result = getFunctionCalls(promptResponsePayload)

expectType<{
id: string, function: {
name: string,
arguments: string,
}
}[]>(result)
}
34 changes: 31 additions & 3 deletions lib/prompt.js
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,38 @@ export async function prompt(userPrompt, promptOptions) {
}
);

const data = await response.json();
if (response.ok) {
const data = await response.json();

return {
requestId: response.headers.get("x-request-id"),
message: data.choices[0].message,
};
}

const requestId = response.headers.get("x-request-id");
return {
requestId: response.headers.get("x-request-id"),
message: data.choices[0].message,
requestId: requestId,
message: {
role: "Sssistant",
content: `Sorry, an error occured with the chat completions API. (Status: ${response.status}, request ID: ${requestId})`,
},
};
}

/** @type {import('..').GetFunctionCallsInterface} */
export function getFunctionCalls(payload) {
const functionCalls = payload.message.tool_calls;

if (!functionCalls) return [];

return functionCalls.map((call) => {
return {
id: call.id,
function: {
name: call.function.name,
arguments: call.function.arguments,
},
};
});
}
100 changes: 99 additions & 1 deletion test/prompt.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { test, suite } from "node:test";

import { MockAgent } from "undici";

import { prompt } from "../index.js";
import { prompt, getFunctionCalls } from "../index.js";

suite("prompt", () => {
test("smoke", (t) => {
Expand Down Expand Up @@ -267,4 +267,102 @@ suite("prompt", () => {
},
});
});

test("Handles error", async (t) => {
const mockAgent = new MockAgent();
function fetchMock(url, opts) {
opts ||= {};
opts.dispatcher = mockAgent;
return fetch(url, opts);
}

mockAgent.disableNetConnect();
const mockPool = mockAgent.get("https://api.githubcopilot.com");
mockPool
.intercept({
method: "post",
path: `/chat/completions`,
body: JSON.stringify({
messages: [
{
role: "system",
content: "You are a helpful assistant.",
},
{
role: "user",
content: "What is the capital of France?",
},
],
model: "gpt-4",
}),
})
.reply(400, "Bad Request", {
headers: {
"content-type": "text/plain",
"x-request-id": "<request-id>",
},
});

const result = await prompt("What is the capital of France?", {
token: "secret",
model: "gpt-4",
request: { fetch: fetchMock },
});

t.assert.deepEqual(result, {
message: {
content:
"Sorry, an error occured with the chat completions API. (Status: 400, request ID: <request-id>)",
role: "Sssistant",
},
requestId: "<request-id>",
});
});

suite("getFunctionCalls()", () => {
test("includes function calls", async (t) => {
const tool_calls = [
{
function: {
arguments: '{\n "order_id": "123"\n}',
name: "get_delivery_date",
},
id: "call_Eko8Jz0mgchNOqiJJrrMr8YW",
type: "function",
},
];
const result = getFunctionCalls({
requestId: "<request-id>",
message: {
role: "assistant",
tool_calls,
},
});

t.assert.deepEqual(
result,
tool_calls.map((call) => {
return {
id: call.id,
function: {
name: call.function.name,
arguments: call.function.arguments,
},
};
})
);
});

test("does not include function calls", async (t) => {
const result = getFunctionCalls({
requestId: "<request-id>",
message: {
content: "Hello! How can I assist you today?",
role: "assistant",
},
});

t.assert.deepEqual(result, []);
});
});
});

0 comments on commit be9c476

Please sign in to comment.