From 46bbc76800cd5130f17b849610b7efec473b921a Mon Sep 17 00:00:00 2001 From: Ryota-Kawamura Date: Sat, 28 Oct 2023 15:06:47 +0900 Subject: [PATCH] OpenAI Function Calling in LangChain --- L3-function-calling-student.ipynb | 808 ++++++++++++++++++++++++++++++ 1 file changed, 808 insertions(+) create mode 100644 L3-function-calling-student.ipynb diff --git a/L3-function-calling-student.ipynb b/L3-function-calling-student.ipynb new file mode 100644 index 0000000..eadf781 --- /dev/null +++ b/L3-function-calling-student.ipynb @@ -0,0 +1,808 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "16de7336", + "metadata": {}, + "source": [ + "# OpenAI Function Calling In LangChain" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "cb41f5f4-df8d-4d04-9eaa-193b8c29b00b", + "metadata": { + "height": 115, + "tags": [] + }, + "outputs": [], + "source": [ + "import os\n", + "import openai\n", + "\n", + "from dotenv import load_dotenv, find_dotenv\n", + "_ = load_dotenv(find_dotenv()) # read local .env file\n", + "openai.api_key = os.environ['OPENAI_API_KEY']" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "aa1dddf9-8e44-4454-9d44-f8372cccf5ac", + "metadata": { + "height": 47 + }, + "outputs": [], + "source": [ + "from typing import List\n", + "from pydantic import BaseModel, Field" + ] + }, + { + "cell_type": "markdown", + "id": "ad68931a-f806-4ea9-969c-93b3902baf9b", + "metadata": {}, + "source": [ + "## Pydantic Syntax\n", + "\n", + "Pydantic data classes are a blend of Python's data classes with the validation power of Pydantic. 
\n", + "\n", + "They offer a concise way to define data structures while ensuring that the data adheres to specified types and constraints.\n", + "\n", + "In standard Python you would create a class like this:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "e1557226-36e2-484b-a2fb-bb7e3180342c", + "metadata": { + "height": 98 + }, + "outputs": [], + "source": [ + "class User:\n", + " def __init__(self, name: str, age: int, email: str):\n", + " self.name = name\n", + " self.age = age\n", + " self.email = email" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "11b9b584-74dc-49b8-a7fe-3865368774e9", + "metadata": { + "height": 30 + }, + "outputs": [], + "source": [ + "foo = User(name=\"Joe\",age=32, email=\"joe@gmail.com\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "0f6f9e9c-b83a-4859-8e65-e6488e05a071", + "metadata": { + "height": 30 + }, + "outputs": [ + { + "data": { + "text/plain": [ + "'Joe'" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "foo.name" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "10a6a0de-d7dc-414d-baaf-fa43c6d1f410", + "metadata": { + "height": 30 + }, + "outputs": [], + "source": [ + "foo = User(name=\"Joe\",age=\"bar\", email=\"joe@gmail.com\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "613b7b12-f061-44bc-989d-433cab609164", + "metadata": { + "height": 30 + }, + "outputs": [ + { + "data": { + "text/plain": [ + "'bar'" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "foo.age" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "c541cb8d-fc55-4c94-a04f-a877cccf10ec", + "metadata": { + "height": 81 + }, + "outputs": [], + "source": [ + "class pUser(BaseModel):\n", + " name: str\n", + " age: int\n", + " email: str" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": 
"27394d22-73e3-4918-9bdf-18cd7c973942", + "metadata": { + "height": 30 + }, + "outputs": [], + "source": [ + "foo_p = pUser(name=\"Jane\", age=32, email=\"jane@gmail.com\")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "49f25241-ff47-454f-bac4-ba20ab937d70", + "metadata": { + "height": 30 + }, + "outputs": [ + { + "data": { + "text/plain": [ + "'Jane'" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "foo_p.name" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "37030df3-ec11-4523-ac66-b88f90099d1b", + "metadata": { + "height": 30 + }, + "outputs": [ + { + "ename": "ValidationError", + "evalue": "1 validation error for pUser\nage\n value is not a valid integer (type=type_error.integer)", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mValidationError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[1;32mIn[11], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m foo_p \u001b[38;5;241m=\u001b[39m \u001b[43mpUser\u001b[49m\u001b[43m(\u001b[49m\u001b[43mname\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mJane\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mage\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mbar\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43memail\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mjane@gmail.com\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n", + "File \u001b[1;32m~\\.pyenv\\pyenv-win\\versions\\3.11.3\\Lib\\site-packages\\pydantic\\main.py:341\u001b[0m, in \u001b[0;36mpydantic.main.BaseModel.__init__\u001b[1;34m()\u001b[0m\n", + 
"\u001b[1;31mValidationError\u001b[0m: 1 validation error for pUser\nage\n value is not a valid integer (type=type_error.integer)" + ] + } + ], + "source": [ + "foo_p = pUser(name=\"Jane\", age=\"bar\", email=\"jane@gmail.com\")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "911e7677-cc5d-4957-b6c3-b3ba1493de33", + "metadata": { + "height": 47 + }, + "outputs": [], + "source": [ + "class Class(BaseModel):\n", + " students: List[pUser]" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "14920e50-688e-4dd4-9207-9b59c6b018c6", + "metadata": { + "height": 64 + }, + "outputs": [], + "source": [ + "obj = Class(\n", + " students=[pUser(name=\"Jane\", age=32, email=\"jane@gmail.com\")]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "cede1035-7581-4203-bab7-b8e6363c931f", + "metadata": { + "height": 30 + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Class(students=[pUser(name='Jane', age=32, email='jane@gmail.com')])" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "obj" + ] + }, + { + "cell_type": "markdown", + "id": "b9c12cef-3a2d-46da-9c45-9a117e10f4a4", + "metadata": {}, + "source": [ + "## Pydantic to OpenAI function definition\n" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "617ceea9-009f-4325-adae-85ab29fccd68", + "metadata": { + "height": 64 + }, + "outputs": [], + "source": [ + "class WeatherSearch(BaseModel):\n", + " \"\"\"Call this with an airport code to get the weather at that airport\"\"\"\n", + " airport_code: str = Field(description=\"airport code to get weather for\")" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "b22a438c-6692-47f9-9e00-1a95d04c6dd3", + "metadata": { + "height": 30 + }, + "outputs": [], + "source": [ + "from langchain.utils.openai_functions import convert_pydantic_to_openai_function" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + 
"id": "97e152c4-8d04-4a02-b363-ee7691f60e31", + "metadata": { + "height": 30 + }, + "outputs": [], + "source": [ + "weather_function = convert_pydantic_to_openai_function(WeatherSearch)" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "f36b041e-bd28-4e25-a0c1-fbeeeee4ae53", + "metadata": { + "height": 30 + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'name': 'WeatherSearch',\n", + " 'description': 'Call this with an airport code to get the weather at that airport',\n", + " 'parameters': {'title': 'WeatherSearch',\n", + " 'description': 'Call this with an airport code to get the weather at that airport',\n", + " 'type': 'object',\n", + " 'properties': {'airport_code': {'title': 'Airport Code',\n", + " 'description': 'airport code to get weather for',\n", + " 'type': 'string'}},\n", + " 'required': ['airport_code']}}" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "weather_function" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "c5d7c573-2f84-441d-ab73-a3dd263318d4", + "metadata": { + "height": 47 + }, + "outputs": [], + "source": [ + "class WeatherSearch1(BaseModel):\n", + " airport_code: str = Field(description=\"airport code to get weather for\")" + ] + }, + { + "cell_type": "markdown", + "id": "3d99b688-a9a7-4446-977f-07918a5d93e1", + "metadata": {}, + "source": [ + "Note: The next cell is expected to generate an error." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "bb08f095-8190-41c5-b49c-e3580cedf992", + "metadata": { + "height": 30 + }, + "outputs": [ + { + "ename": "KeyError", + "evalue": "'description'", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mKeyError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[1;32mIn[20], line 1\u001b[0m\n\u001b[1;32m----> 1\u001b[0m \u001b[43mconvert_pydantic_to_openai_function\u001b[49m\u001b[43m(\u001b[49m\u001b[43mWeatherSearch1\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[1;32m~\\.pyenv\\pyenv-win\\versions\\3.11.3\\Lib\\site-packages\\langchain\\utils\\openai_functions.py:29\u001b[0m, in \u001b[0;36mconvert_pydantic_to_openai_function\u001b[1;34m(model, name, description)\u001b[0m\n\u001b[0;32m 25\u001b[0m schema \u001b[38;5;241m=\u001b[39m dereference_refs(model\u001b[38;5;241m.\u001b[39mschema())\n\u001b[0;32m 26\u001b[0m schema\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdefinitions\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;28;01mNone\u001b[39;00m)\n\u001b[0;32m 27\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m {\n\u001b[0;32m 28\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mname\u001b[39m\u001b[38;5;124m\"\u001b[39m: name \u001b[38;5;129;01mor\u001b[39;00m schema[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtitle\u001b[39m\u001b[38;5;124m\"\u001b[39m],\n\u001b[1;32m---> 29\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdescription\u001b[39m\u001b[38;5;124m\"\u001b[39m: description \u001b[38;5;129;01mor\u001b[39;00m \u001b[43mschema\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mdescription\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m,\n\u001b[0;32m 30\u001b[0m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mparameters\u001b[39m\u001b[38;5;124m\"\u001b[39m: schema,\n\u001b[0;32m 
31\u001b[0m }\n", + "\u001b[1;31mKeyError\u001b[0m: 'description'" + ] + } + ], + "source": [ + "convert_pydantic_to_openai_function(WeatherSearch1)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "ed22668a-e188-45a5-844e-deee62f9bf51", + "metadata": { + "height": 64 + }, + "outputs": [], + "source": [ + "class WeatherSearch2(BaseModel):\n", + " \"\"\"Call this with an airport code to get the weather at that airport\"\"\"\n", + " airport_code: str" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "9e001e87-4338-4720-99b3-9dc4cb3e4faf", + "metadata": { + "height": 30 + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'name': 'WeatherSearch2',\n", + " 'description': 'Call this with an airport code to get the weather at that airport',\n", + " 'parameters': {'title': 'WeatherSearch2',\n", + " 'description': 'Call this with an airport code to get the weather at that airport',\n", + " 'type': 'object',\n", + " 'properties': {'airport_code': {'title': 'Airport Code', 'type': 'string'}},\n", + " 'required': ['airport_code']}}" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "convert_pydantic_to_openai_function(WeatherSearch2)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "a35261e8-2d36-43a8-a051-a79bef35c8dd", + "metadata": { + "height": 30 + }, + "outputs": [], + "source": [ + "from langchain.chat_models import ChatOpenAI" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "93bda0a6-9206-407a-b0da-966f9442a40c", + "metadata": { + "height": 30 + }, + "outputs": [], + "source": [ + "model = ChatOpenAI()" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "afe342d4-a7ef-49cd-b760-aa9a176d64d5", + "metadata": { + "height": 30 + }, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content='', additional_kwargs={'function_call': {'name': 'WeatherSearch', 'arguments': '{\\n \"airport_code\": 
\"SFO\"\\n}'}})" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.invoke(\"what is the weather in SF today?\", functions=[weather_function])" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "511e12b6-bcfb-4862-b377-4251de9969ea", + "metadata": { + "height": 30 + }, + "outputs": [], + "source": [ + "model_with_function = model.bind(functions=[weather_function])" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "de6241d9-667c-4b97-a50f-95c046fa640c", + "metadata": { + "height": 30 + }, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content='', additional_kwargs={'function_call': {'name': 'WeatherSearch', 'arguments': '{\\n \"airport_code\": \"SFO\"\\n}'}})" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model_with_function.invoke(\"what is the weather in sf?\")" + ] + }, + { + "cell_type": "markdown", + "id": "ae78d6dd-bb38-4e55-9b65-0ef9005a52b9", + "metadata": {}, + "source": [ + "## Forcing it to use a function\n", + "\n", + "We can force the model to use a function" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "e7d2fbc0-df39-4e93-a22f-39a6285272b4", + "metadata": { + "height": 47 + }, + "outputs": [], + "source": [ + "model_with_forced_function = model.bind(functions=[weather_function], function_call={\"name\":\"WeatherSearch\"})" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "cd9f4063-9e15-41d7-9cf9-253548534176", + "metadata": { + "height": 30 + }, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content='', additional_kwargs={'function_call': {'name': 'WeatherSearch', 'arguments': '{\\n \"airport_code\": \"SFO\"\\n}'}})" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model_with_forced_function.invoke(\"what is the weather in sf?\")" + ] + }, + { + 
"cell_type": "code", + "execution_count": 30, + "id": "314ca7e6-b77c-4b9d-9c93-da6ef3c9c6f8", + "metadata": { + "height": 30 + }, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content='', additional_kwargs={'function_call': {'name': 'WeatherSearch', 'arguments': '{\\n \"airport_code\": \"SFO\"\\n}'}})" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model_with_forced_function.invoke(\"hi!\")" + ] + }, + { + "cell_type": "markdown", + "id": "5ac391c3-cd81-4423-a33e-6583ec534850", + "metadata": {}, + "source": [ + "## Using in a chain\n", + "\n", + "We can use this model bound to a function in a chain as we normally would." + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "c12c86df-d628-4176-9f4e-24fb5a953a5d", + "metadata": { + "height": 30 + }, + "outputs": [], + "source": [ + "from langchain.prompts import ChatPromptTemplate" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "83f00dc6-5d22-44a5-a0a0-4fe1ab8167ac", + "metadata": { + "height": 81 + }, + "outputs": [], + "source": [ + "prompt = ChatPromptTemplate.from_messages([\n", + " (\"system\", \"You are a helpful assistant\"),\n", + " (\"user\", \"{input}\")\n", + "])" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "1eee3f5f-2176-4777-8c8a-2a197acc47a7", + "metadata": { + "height": 30 + }, + "outputs": [], + "source": [ + "chain = prompt | model_with_function" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "8587907e-b4c3-4acd-9e58-1137047d0fee", + "metadata": { + "height": 30 + }, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content='', additional_kwargs={'function_call': {'name': 'WeatherSearch', 'arguments': '{\\n \"airport_code\": \"SFO\"\\n}'}})" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "chain.invoke({\"input\": \"what is the weather in sf?\"})" + ] + }, + { + 
"cell_type": "markdown", + "id": "f317408d-de5e-4774-993e-a8ac31a2f5fe", + "metadata": {}, + "source": [ + "## Using multiple functions\n", + "\n", + "Even better, we can pass a set of functions and let the LLM decide which to use based on the question context." + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "48c8e42e-f84e-4822-b1ee-a9955fa301c4", + "metadata": { + "height": 81 + }, + "outputs": [], + "source": [ + "class ArtistSearch(BaseModel):\n", + " \"\"\"Call this to get the names of songs by a particular artist\"\"\"\n", + " artist_name: str = Field(description=\"name of artist to look up\")\n", + " n: int = Field(description=\"number of results\")" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "9a51599b-ee32-4f74-9925-b6264bf43242", + "metadata": { + "height": 81 + }, + "outputs": [], + "source": [ + "functions = [\n", + " convert_pydantic_to_openai_function(WeatherSearch),\n", + " convert_pydantic_to_openai_function(ArtistSearch),\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "e0cc9e3b-ba38-4eff-b285-d02ee5963725", + "metadata": { + "height": 30 + }, + "outputs": [], + "source": [ + "model_with_functions = model.bind(functions=functions)" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "d01cd501-03b4-4207-b1be-0c33c12a0fa5", + "metadata": { + "height": 30 + }, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content='', additional_kwargs={'function_call': {'name': 'WeatherSearch', 'arguments': '{\\n \"airport_code\": \"SFO\"\\n}'}})" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model_with_functions.invoke(\"what is the weather in sf?\")" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "31599518-7387-4d08-9d68-8f5ba7282e8b", + "metadata": { + "height": 30 + }, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content='', 
additional_kwargs={'function_call': {'name': 'ArtistSearch', 'arguments': '{\\n \"artist_name\": \"taylor swift\",\\n \"n\": 3\\n}'}})" + ] + }, + "execution_count": 39, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model_with_functions.invoke(\"what are three songs by taylor swift?\")" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "4e6df989-6ea3-48af-b2c0-10978e0f4142", + "metadata": { + "height": 30 + }, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content='Hello! How can I assist you today?')" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model_with_functions.invoke(\"hi!\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b29653b7-1382-4375-b681-10c51964fff5", + "metadata": { + "height": 30 + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}