Skip to content
This repository has been archived by the owner on Jun 6, 2024. It is now read-only.

Move Function Arguments to String #422

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package com.theokanning.openai.completion.chat;

import com.fasterxml.jackson.databind.JsonNode;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
Expand All @@ -18,6 +17,6 @@ public class ChatFunctionCall {
/**
* The arguments of the call produced by the model, represented as a JsonNode for easy manipulation.
*/
JsonNode arguments;
String arguments;

}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
import lombok.AllArgsConstructor;
import lombok.Getter;

import java.util.*;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
* Token calculation tool class
Expand Down Expand Up @@ -173,12 +177,12 @@ public static int tokens(String modelName, List<ChatMessage> messages) {
Encoding encoding = getEncoding(modelName);
int tokensPerMessage = 0;
int tokensPerName = 0;
//3.5统一处理
//3.5
if (modelName.equals("gpt-3.5-turbo-0301") || modelName.equals("gpt-3.5-turbo")) {
tokensPerMessage = 4;
tokensPerName = -1;
}
//4.0统一处理
//4.0
if (modelName.equals("gpt-4") || modelName.equals("gpt-4-0314")) {
tokensPerMessage = 3;
tokensPerName = 1;
Expand Down
28 changes: 21 additions & 7 deletions example/src/main/java/example/OpenAiApiDynamicFunctionExample.java
Original file line number Diff line number Diff line change
@@ -1,17 +1,30 @@
package example;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.theokanning.openai.completion.chat.*;
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatFunctionCall;
import com.theokanning.openai.completion.chat.ChatFunctionDynamic;
import com.theokanning.openai.completion.chat.ChatFunctionProperty;
import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.completion.chat.ChatMessageRole;
import com.theokanning.openai.service.OpenAiService;

import java.util.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Scanner;

public class OpenAiApiDynamicFunctionExample {

static ObjectMapper mapper = new ObjectMapper();
private static JsonNode getWeather(String location, String unit) {
ObjectMapper mapper = new ObjectMapper();

ObjectNode response = mapper.createObjectNode();
response.put("location", location);
response.put("unit", unit);
Expand All @@ -20,7 +33,7 @@ private static JsonNode getWeather(String location, String unit) {
return response;
}

public static void main(String... args) {
public static void main(String... args) throws JsonProcessingException {
String token = System.getenv("OPENAI_TOKEN");
OpenAiService service = new OpenAiService(token);

Expand Down Expand Up @@ -68,8 +81,9 @@ public static void main(String... args) {
ChatFunctionCall functionCall = responseMessage.getFunctionCall();
if (functionCall != null) {
if (functionCall.getName().equals("get_weather")) {
String location = functionCall.getArguments().get("location").asText();
String unit = functionCall.getArguments().get("unit").asText();
JsonNode arguments = mapper.readTree(functionCall.getArguments());
String location = arguments.get("location").asText();
String unit = arguments.get("unit").asText();
JsonNode weather = getWeather(location, unit);
ChatMessage weatherMessage = new ChatMessage(ChatMessageRole.FUNCTION.value(), weather.toString(), "get_weather");
messages.add(weatherMessage);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
package example;

import com.theokanning.openai.completion.chat.*;
import com.theokanning.openai.completion.chat.ChatCompletionChunk;
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatFunction;
import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.completion.chat.ChatMessageRole;
import com.theokanning.openai.service.FunctionExecutor;
import com.theokanning.openai.service.OpenAiService;
import example.OpenAiApiFunctionsExample.Weather;
import example.OpenAiApiFunctionsExample.WeatherResponse;
import io.reactivex.Flowable;

import java.util.*;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Random;
import java.util.Scanner;
import java.util.concurrent.atomic.AtomicBoolean;

public class OpenAiApiFunctionsWithStreamExample {
Expand All @@ -34,7 +43,7 @@ public static void main(String... args) {
while (true) {
ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest
.builder()
.model("gpt-3.5-turbo-0613")
.model("gpt-4-1106-preview")
.messages(messages)
.functions(functionExecutor.getFunctions())
.functionCall(ChatCompletionRequest.ChatCompletionRequestFunctionCall.of("auto"))
Expand All @@ -48,6 +57,7 @@ public static void main(String... args) {
ChatMessage chatMessage = service.mapStreamToAccumulator(flowable)
.doOnNext(accumulator -> {
if (accumulator.isFunctionCall()) {
System.out.println("Trying to execute " + accumulator.getAccumulatedChatFunctionCall().getArguments());
if (isFirst.getAndSet(false)) {
System.out.println("Executing function " + accumulator.getAccumulatedChatFunctionCall().getName() + "...");
}
Expand Down Expand Up @@ -83,4 +93,4 @@ public static void main(String... args) {
}
}

}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@
import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.completion.chat.ChatMessageRole;

import java.util.*;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

public class FunctionExecutor {

Expand Down Expand Up @@ -81,8 +85,8 @@ public <T> T execute(ChatFunctionCall call) {
ChatFunction function = FUNCTIONS.get(call.getName());
Object obj;
try {
JsonNode arguments = call.getArguments();
obj = MAPPER.readValue(arguments instanceof TextNode ? arguments.asText() : arguments.toPrettyString(), function.getParametersClass());
String arguments = call.getArguments();
obj = MAPPER.readValue(arguments, function.getParametersClass());
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,34 @@
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.PropertyNamingStrategy;
import com.fasterxml.jackson.databind.node.TextNode;
import com.theokanning.openai.*;
import com.theokanning.openai.assistants.*;
import com.theokanning.openai.audio.*;
import com.theokanning.openai.DeleteResult;
import com.theokanning.openai.ListSearchParameters;
import com.theokanning.openai.OpenAiError;
import com.theokanning.openai.OpenAiHttpException;
import com.theokanning.openai.OpenAiResponse;
import com.theokanning.openai.assistants.Assistant;
import com.theokanning.openai.assistants.AssistantFile;
import com.theokanning.openai.assistants.AssistantFileRequest;
import com.theokanning.openai.assistants.AssistantRequest;
import com.theokanning.openai.assistants.ModifyAssistantRequest;
import com.theokanning.openai.audio.CreateSpeechRequest;
import com.theokanning.openai.audio.CreateTranscriptionRequest;
import com.theokanning.openai.audio.CreateTranslationRequest;
import com.theokanning.openai.audio.TranscriptionResult;
import com.theokanning.openai.audio.TranslationResult;
import com.theokanning.openai.billing.BillingUsage;
import com.theokanning.openai.billing.Subscription;
import com.theokanning.openai.client.OpenAiApi;
import com.theokanning.openai.completion.CompletionChunk;
import com.theokanning.openai.completion.CompletionRequest;
import com.theokanning.openai.completion.CompletionResult;
import com.theokanning.openai.completion.chat.*;
import com.theokanning.openai.completion.chat.ChatCompletionChunk;
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatCompletionResult;
import com.theokanning.openai.completion.chat.ChatFunction;
import com.theokanning.openai.completion.chat.ChatFunctionCall;
import com.theokanning.openai.completion.chat.ChatMessage;
import com.theokanning.openai.completion.chat.ChatMessageRole;
import com.theokanning.openai.edit.EditRequest;
import com.theokanning.openai.edit.EditResult;
import com.theokanning.openai.embedding.EmbeddingRequest;
Expand Down Expand Up @@ -48,7 +65,12 @@
import io.reactivex.BackpressureStrategy;
import io.reactivex.Flowable;
import io.reactivex.Single;
import okhttp3.*;
import okhttp3.ConnectionPool;
import okhttp3.MediaType;
import okhttp3.MultipartBody;
import okhttp3.OkHttpClient;
import okhttp3.RequestBody;
import okhttp3.ResponseBody;
import retrofit2.Call;
import retrofit2.HttpException;
import retrofit2.Retrofit;
Expand Down Expand Up @@ -584,7 +606,6 @@ public static ObjectMapper defaultObjectMapper() {
mapper.setPropertyNamingStrategy(PropertyNamingStrategy.SNAKE_CASE);
mapper.addMixIn(ChatFunction.class, ChatFunctionMixIn.class);
mapper.addMixIn(ChatCompletionRequest.class, ChatCompletionRequestMixIn.class);
mapper.addMixIn(ChatFunctionCall.class, ChatFunctionCallMixIn.class);
return mapper;
}

Expand Down Expand Up @@ -617,8 +638,8 @@ public Flowable<ChatMessageAccumulator> mapStreamToAccumulator(Flowable<ChatComp
functionCall.setName((functionCall.getName() == null ? "" : functionCall.getName()) + namePart);
}
if (messageChunk.getFunctionCall().getArguments() != null) {
String argumentsPart = messageChunk.getFunctionCall().getArguments() == null ? "" : messageChunk.getFunctionCall().getArguments().asText();
functionCall.setArguments(new TextNode((functionCall.getArguments() == null ? "" : functionCall.getArguments().asText()) + argumentsPart));
String argumentsPart = messageChunk.getFunctionCall().getArguments() == null ? "" : messageChunk.getFunctionCall().getArguments();
functionCall.setArguments((functionCall.getArguments() == null ? "" : functionCall.getArguments()) + argumentsPart);
}
accumulatedMessage.setFunctionCall(functionCall);
} else {
Expand All @@ -627,7 +648,6 @@ public Flowable<ChatMessageAccumulator> mapStreamToAccumulator(Flowable<ChatComp

if (chunk.getChoices().get(0).getFinishReason() != null) { // last
if (functionCall.getArguments() != null) {
functionCall.setArguments(mapper.readTree(functionCall.getArguments().asText()));
accumulatedMessage.setFunctionCall(functionCall);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.PropertyNamingStrategy;
import com.theokanning.openai.ListSearchParameters;
import com.theokanning.openai.OpenAiResponse;
import com.theokanning.openai.assistants.Assistant;
import com.theokanning.openai.assistants.AssistantFunction;
Expand All @@ -15,15 +14,12 @@
import com.theokanning.openai.assistants.Tool;
import com.theokanning.openai.completion.chat.ChatCompletionRequest;
import com.theokanning.openai.completion.chat.ChatFunction;
import com.theokanning.openai.completion.chat.ChatFunctionCall;
import com.theokanning.openai.messages.Message;
import com.theokanning.openai.messages.MessageRequest;
import com.theokanning.openai.runs.RequiredAction;
import com.theokanning.openai.runs.Run;
import com.theokanning.openai.runs.RunCreateRequest;
import com.theokanning.openai.runs.RunStep;
import com.theokanning.openai.runs.SubmitToolOutputRequestItem;
import com.theokanning.openai.runs.SubmitToolOutputs;
import com.theokanning.openai.runs.SubmitToolOutputsRequest;
import com.theokanning.openai.runs.ToolCall;
import com.theokanning.openai.threads.Thread;
Expand All @@ -35,9 +31,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;

class AssistantFunctionTest {
Expand All @@ -53,8 +47,7 @@ void createRetrieveRun() throws JsonProcessingException {
mapper.setPropertyNamingStrategy(PropertyNamingStrategy.SNAKE_CASE);
mapper.addMixIn(ChatFunction.class, ChatFunctionMixIn.class);
mapper.addMixIn(ChatCompletionRequest.class, ChatCompletionRequestMixIn.class);
mapper.addMixIn(ChatFunctionCall.class, ChatFunctionCallMixIn.class);


String funcDef = "{\n" +
" \"type\": \"object\",\n" +
" \"properties\": {\n" +
Expand All @@ -79,8 +72,8 @@ void createRetrieveRun() throws JsonProcessingException {
List<Tool> toolList = new ArrayList<>();
Tool funcTool = new Tool(AssistantToolsEnum.FUNCTION, function);
toolList.add(funcTool);


AssistantRequest assistantRequest = AssistantRequest.builder()
.model(TikTokensUtil.ModelEnum.GPT_4_1106_preview.getName())
.name("MATH_TUTOR")
Expand All @@ -107,8 +100,8 @@ void createRetrieveRun() throws JsonProcessingException {
assertNotNull(run);

Run retrievedRun = service.retrieveRun(thread.getId(), run.getId());
while (!(retrievedRun.getStatus().equals("completed"))
&& !(retrievedRun.getStatus().equals("failed"))
while (!(retrievedRun.getStatus().equals("completed"))
&& !(retrievedRun.getStatus().equals("failed"))
&& !(retrievedRun.getStatus().equals("requires_action"))){
retrievedRun = service.retrieveRun(thread.getId(), run.getId());
}
Expand Down Expand Up @@ -142,7 +135,7 @@ void createRetrieveRun() throws JsonProcessingException {
List<Message> messages = response.getData();

System.out.println(mapper.writeValueAsString(messages));

}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,19 @@
import com.theokanning.openai.DeleteResult;
import com.theokanning.openai.ListSearchParameters;
import com.theokanning.openai.OpenAiResponse;
import com.theokanning.openai.assistants.*;
import com.theokanning.openai.assistants.Assistant;
import com.theokanning.openai.assistants.AssistantFile;
import com.theokanning.openai.assistants.AssistantFileRequest;
import com.theokanning.openai.assistants.AssistantRequest;
import com.theokanning.openai.assistants.AssistantToolsEnum;
import com.theokanning.openai.assistants.ModifyAssistantRequest;
import com.theokanning.openai.assistants.Tool;
import com.theokanning.openai.file.File;
import com.theokanning.openai.utils.TikTokensUtil;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Test;

import java.util.Collections;
import java.util.List;

import static org.junit.jupiter.api.Assertions.*;

Expand All @@ -19,7 +24,7 @@ public class AssistantTest {
public static final String MATH_TUTOR = "Math Tutor";
public static final String ASSISTANT_INSTRUCTION = "You are a personal Math Tutor.";

static String token = System.getenv("OPENAI_TOKEN");;
static String token = System.getenv("OPENAI_TOKEN");

static OpenAiService service = new OpenAiService(token);

Expand Down Expand Up @@ -105,9 +110,7 @@ static void clean() {
.limit(100)
.build();
OpenAiResponse<Assistant> assistantListAssistant = service.listAssistants(queryFilter);
assistantListAssistant.getData().forEach(assistant ->{
service.deleteAssistant(assistant.getId());
});
assistantListAssistant.getData().forEach(assistant -> service.deleteAssistant(assistant.getId()));
}

private static File uploadAssistantFile() {
Expand Down Expand Up @@ -137,7 +140,7 @@ private static void validateAssistantResponse(Assistant assistantResponse) {
assertNotNull(assistantResponse.getId());
assertNotNull(assistantResponse.getCreatedAt());
assertNotNull(assistantResponse.getObject());
assertEquals(assistantResponse.getTools().get(0).getType(), AssistantToolsEnum.CODE_INTERPRETER);
assertEquals(AssistantToolsEnum.CODE_INTERPRETER, assistantResponse.getTools().get(0).getType());
assertEquals(MATH_TUTOR, assistantResponse.getName());
}
}
Loading