Skip to content

Commit

Permalink
Add Google Generative AI provider for ai/core functions. (#1261)
Browse files Browse the repository at this point in the history
  • Loading branch information
lgrammel authored Apr 2, 2024
1 parent 1088e04 commit 2b991c4
Show file tree
Hide file tree
Showing 22 changed files with 983 additions and 9 deletions.
5 changes: 5 additions & 0 deletions .changeset/cool-experts-pull.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'ai': patch
---

Add Google Generative AI provider for ai/core functions.
3 changes: 1 addition & 2 deletions examples/ai-core/.env.example
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
OPENAI_API_KEY=""
MISTRAL_API_KEY=""
PERPLEXITY_API_KEY=""
FIREWORKS_API_KEY=""
GOOGLE_GENERATIVE_AI_API_KEY=""
26 changes: 26 additions & 0 deletions examples/ai-core/src/generate-text/google-multimodal.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import { experimental_generateText } from 'ai';
import { google } from 'ai/google';
import dotenv from 'dotenv';
import fs from 'node:fs';

dotenv.config();

async function main() {
const result = await experimental_generateText({
model: google.generativeAI('models/gemini-pro-vision'),
maxTokens: 512,
messages: [
{
role: 'user',
content: [
{ type: 'text', text: 'Describe the image in detail.' },
{ type: 'image', image: fs.readFileSync('./data/comic-cat.png') },
],
},
],
});

console.log(result.text);
}

main();
60 changes: 60 additions & 0 deletions examples/ai-core/src/generate-text/google-tool-call.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import { experimental_generateText, tool } from 'ai';
import { google } from 'ai/google';
import dotenv from 'dotenv';
import { z } from 'zod';
import { weatherTool } from '../tools/weather-tool';

dotenv.config();

async function main() {
const result = await experimental_generateText({
model: google.generativeAI('models/gemini-pro'),
maxTokens: 512,
tools: {
weather: weatherTool,
cityAttractions: tool({
parameters: z.object({ city: z.string() }),
}),
},
prompt:
'What is the weather in San Francisco and what attractions should I visit?',
});

// typed tool calls:
for (const toolCall of result.toolCalls) {
switch (toolCall.toolName) {
case 'cityAttractions': {
toolCall.args.city; // string
break;
}

case 'weather': {
toolCall.args.location; // string
break;
}
}
}

// typed tool results for tools with execute method:
for (const toolResult of result.toolResults) {
switch (toolResult.toolName) {
// NOT AVAILABLE (NO EXECUTE METHOD)
// case 'cityAttractions': {
// toolResult.args.city; // string
// toolResult.result;
// break;
// }

case 'weather': {
toolResult.args.location; // string
toolResult.result.location; // string
toolResult.result.temperature; // number
break;
}
}
}

console.log(JSON.stringify(result, null, 2));
}

main();
19 changes: 19 additions & 0 deletions examples/ai-core/src/generate-text/google.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import { experimental_generateText } from 'ai';
import { google } from 'ai/google';
import dotenv from 'dotenv';

dotenv.config();

async function main() {
const result = await experimental_generateText({
model: google.generativeAI('models/gemini-pro'),
prompt: 'Invent a new holiday and describe its traditions.',
});

console.log(result.text);
console.log();
console.log('Token usage:', result.usage);
console.log('Finish reason:', result.finishReason);
}

main();
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ async function main() {
{
type: 'image',
image: new URL(
'https://raw.githubusercontent.com/vercel/ai/v3.1-canary/examples/ai-core/data/comic-cat.png',
'https://github.com/vercel/ai/blob/main/examples/ai-core/data/comic-cat.png?raw=true',
),
},
],
Expand Down
90 changes: 90 additions & 0 deletions examples/ai-core/src/stream-text/google-chatbot-with-tools.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import {
ExperimentalMessage,
ToolCallPart,
ToolResultPart,
experimental_streamText,
} from 'ai';
import { google } from 'ai/google';
import dotenv from 'dotenv';
import * as readline from 'node:readline/promises';
import { weatherTool } from '../tools/weather-tool';

dotenv.config();

const terminal = readline.createInterface({
input: process.stdin,
output: process.stdout,
});

const messages: ExperimentalMessage[] = [];

async function main() {
let toolResponseAvailable = false;

while (true) {
if (!toolResponseAvailable) {
const userInput = await terminal.question('You: ');
messages.push({ role: 'user', content: userInput });
}

const result = await experimental_streamText({
model: google.generativeAI('models/gemini-pro'),
tools: { weatherTool },
system: `You are a helpful, respectful and honest assistant.`,
messages,
});

toolResponseAvailable = false;
let fullResponse = '';
const toolCalls: ToolCallPart[] = [];
const toolResponses: ToolResultPart[] = [];

for await (const delta of result.fullStream) {
switch (delta.type) {
case 'text-delta': {
if (fullResponse.length === 0) {
process.stdout.write('\nAssistant: ');
}

fullResponse += delta.textDelta;
process.stdout.write(delta.textDelta);
break;
}

case 'tool-call': {
toolCalls.push(delta);

process.stdout.write(
`\nTool call: '${delta.toolName}' ${JSON.stringify(delta.args)}`,
);
break;
}

case 'tool-result': {
toolResponses.push(delta);

process.stdout.write(
`\nTool response: '${delta.toolName}' ${JSON.stringify(
delta.result,
)}`,
);
break;
}
}
}
process.stdout.write('\n\n');

messages.push({
role: 'assistant',
content: [{ type: 'text', text: fullResponse }, ...toolCalls],
});

if (toolResponses.length > 0) {
messages.push({ role: 'tool', content: toolResponses });
}

toolResponseAvailable = toolCalls.length > 0;
}
}

main().catch(console.error);
39 changes: 39 additions & 0 deletions examples/ai-core/src/stream-text/google-chatbot.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import { ExperimentalMessage, experimental_streamText } from 'ai';
import { google } from 'ai/google';
import dotenv from 'dotenv';
import * as readline from 'node:readline/promises';

dotenv.config();

const terminal = readline.createInterface({
input: process.stdin,
output: process.stdout,
});

const messages: ExperimentalMessage[] = [];

async function main() {
while (true) {
const userInput = await terminal.question('You: ');

messages.push({ role: 'user', content: userInput });

const result = await experimental_streamText({
model: google.generativeAI('models/gemini-pro'),
system: `You are a helpful, respectful and honest assistant.`,
messages,
});

let fullResponse = '';
process.stdout.write('\nAssistant: ');
for await (const delta of result.textStream) {
fullResponse += delta;
process.stdout.write(delta);
}
process.stdout.write('\n\n');

messages.push({ role: 'assistant', content: fullResponse });
}
}

main().catch(console.error);
80 changes: 80 additions & 0 deletions examples/ai-core/src/stream-text/google-fullstream.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import { experimental_streamText } from 'ai';
import { google } from 'ai/google';
import dotenv from 'dotenv';
import { z } from 'zod';
import { weatherTool } from '../tools/weather-tool';

dotenv.config();

async function main() {
const result = await experimental_streamText({
model: google.generativeAI('models/gemini-pro'),
tools: {
weather: weatherTool,
cityAttractions: {
parameters: z.object({ city: z.string() }),
},
},
prompt: 'What is the weather in San Francisco?',
});

for await (const part of result.fullStream) {
switch (part.type) {
case 'text-delta': {
console.log('Text delta:', part.textDelta);
break;
}

case 'tool-call': {
switch (part.toolName) {
case 'cityAttractions': {
console.log('TOOL CALL cityAttractions');
console.log(`city: ${part.args.city}`); // string
break;
}

case 'weather': {
console.log('TOOL CALL weather');
console.log(`location: ${part.args.location}`); // string
break;
}
}

break;
}

case 'tool-result': {
switch (part.toolName) {
// NOT AVAILABLE (NO EXECUTE METHOD)
// case 'cityAttractions': {
// console.log('TOOL RESULT cityAttractions');
// console.log(`city: ${part.args.city}`); // string
// console.log(`result: ${part.result}`);
// break;
// }

case 'weather': {
console.log('TOOL RESULT weather');
console.log(`location: ${part.args.location}`); // string
console.log(`temperature: ${part.result.temperature}`); // number
break;
}
}

break;
}

case 'finish': {
console.log('Finish reason:', part.finishReason);
console.log('Usage:', part.usage);
break;
}

case 'error':
console.error('Error:', part.error);
break;
}
}
}

main();
18 changes: 18 additions & 0 deletions examples/ai-core/src/stream-text/google.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import { experimental_streamText } from 'ai';
import { google } from 'ai/google';
import dotenv from 'dotenv';

dotenv.config();

async function main() {
const result = await experimental_streamText({
model: google.generativeAI('models/gemini-pro'),
prompt: 'Invent a new holiday and describe its traditions.',
});

for await (const textPart of result.textStream) {
process.stdout.write(textPart);
}
}

main();
Loading

0 comments on commit 2b991c4

Please sign in to comment.