Skip to content

Commit

Permalink
Restructure AutoCloseable
Browse files Browse the repository at this point in the history
  • Loading branch information
nomisRev committed Jun 27, 2023
1 parent 3503c2a commit 511d14a
Show file tree
Hide file tree
Showing 7 changed files with 40 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ data class AIRuntime<A>(val runtime: suspend (block: AI<A>) -> A) {
CoreAIScope(
defaultModel = LLMModel.GPT_3_5_TURBO_16K,
defaultSerializationModel = LLMModel.GPT_3_5_TURBO_FUNCTIONS,
AIClient = openAiClient,
aiClient = openAiClient,
context = vectorStore,
embeddings = embeddings
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import kotlin.jvm.JvmName
class CoreAIScope(
val defaultModel: LLMModel,
val defaultSerializationModel: LLMModel,
val AIClient: AIClient,
val aiClient: AIClient,
val context: VectorStore,
val embeddings: Embeddings,
val maxDeserializationAttempts: Int = 3,
Expand All @@ -37,7 +37,7 @@ class CoreAIScope(
val numberOfPredictions: Int = 1,
val docsInContext: Int = 20,
val minResponseTokens: Int = 500
) : AutoCloseable {
) {

val logger: KLogger = KotlinLogging.logger {}

Expand Down Expand Up @@ -93,7 +93,7 @@ class CoreAIScope(
CoreAIScope(
defaultModel,
defaultSerializationModel,
this@CoreAIScope.AIClient,
this@CoreAIScope.aiClient,
CombinedVectorStore(store, this@CoreAIScope.context),
this@CoreAIScope.embeddings,
)
Expand Down Expand Up @@ -257,7 +257,7 @@ class CoreAIScope(
temperature = temperature,
maxTokens = maxTokens
)
return AIClient.createCompletion(request).choices.map { it.text }
return aiClient.createCompletion(request).choices.map { it.text }
}

private suspend fun callChatEndpoint(
Expand All @@ -283,7 +283,7 @@ class CoreAIScope(
temperature = temperature,
maxTokens = maxTokens
)
return AIClient.createChatCompletion(request).choices.map { it.message.content }
return aiClient.createChatCompletion(request).choices.map { it.message.content }
}

private suspend fun callChatEndpointWithFunctionsSupport(
Expand Down Expand Up @@ -313,7 +313,7 @@ class CoreAIScope(
functions = functions,
functionCall = mapOf("name" to (firstFnName ?: ""))
)
return AIClient.createChatCompletionWithFunctions(request).choices.map {
return aiClient.createChatCompletionWithFunctions(request).choices.map {
it.message.functionCall
}
}
Expand Down Expand Up @@ -518,8 +518,6 @@ class CoreAIScope(
size = size,
user = user
)
return AIClient.createImages(request)
return aiClient.createImages(request)
}

override fun close() = AIClient.close()
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package com.xebia.functional.xef.llm.openai
import com.xebia.functional.xef.llm.openai.images.ImagesGenerationRequest
import com.xebia.functional.xef.llm.openai.images.ImagesGenerationResponse

interface AIClient : AutoCloseable {
interface AIClient {
suspend fun createCompletion(request: CompletionRequest): CompletionResult

suspend fun createChatCompletion(request: ChatCompletionRequest): ChatCompletionResponse
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ class MockOpenAIClient(

override suspend fun createImages(request: ImagesGenerationRequest): ImagesGenerationResponse =
images(request)

override fun close() {}
}

fun nullEmbeddings(request: EmbeddingRequest): EmbeddingResult {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@ public static class Art {
}

public static void main(String[] args) {
AIScope.run((scope) -> {
try (AIScope scope = new AIScope()) {
Art art = scope.prompt("ASCII art of a cat dancing", Art.class);
System.out.println(art.art);
return null;
});
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ private static class Story {
}

public static void main(String[] args) {
AIScope.run((scope) -> {
try (AIScope scope = new AIScope()) {
Animal animal = scope.prompt("A unique animal species.", Animal.class);
Invention invention = scope.prompt("A groundbreaking invention from the 20th century.", Invention.class);
String storyPrompt =
Expand All @@ -30,7 +30,6 @@ public static void main(String[] args) {
"2. A groundbreaking invention from the 20th century called " + invention.name + " , invented by " + invention.inventor + " in " + invention.year + ", which serves the purpose of " + invention.purpose + ".";
Story story = scope.prompt(storyPrompt, Story.class);
System.out.println("Story about " + animal.name + " and " + invention.name + ": " + story.story);
return null;
});
}
}
}
58 changes: 27 additions & 31 deletions java/src/main/java/com/xebia/functional/xef/java/auto/AIScope.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import com.fasterxml.jackson.module.jsonSchema.JsonSchema;
import com.fasterxml.jackson.module.jsonSchema.JsonSchemaGenerator;
import com.xebia.functional.loom.LoomAdapter;
import com.xebia.functional.xef.AIError;
import com.xebia.functional.xef.auto.CoreAIScope;
import com.xebia.functional.xef.embeddings.Embeddings;
import com.xebia.functional.xef.embeddings.OpenAIEmbeddings;
Expand All @@ -27,46 +26,43 @@
import java.util.List;

public class AIScope implements AutoCloseable {
private CoreAIScope coreAIScope;
private ObjectMapper om;
private JsonSchemaGenerator schemaGen;

public AIScope(CoreAIScope coreAIScope) {
this.coreAIScope = coreAIScope;
this.om = new ObjectMapper();
private final CoreAIScope scope;
private final ObjectMapper om;
private final JsonSchemaGenerator schemaGen;
private final KtorOpenAIClient client;
private final Embeddings embeddings;
private final VectorStore vectorStore;

public AIScope(ObjectMapper om, OpenAIConfig config) {
this.om = om;
this.schemaGen = new JsonSchemaGenerator(om);
this.client = new KtorOpenAIClient(config);
this.embeddings = new OpenAIEmbeddings(config, client);
this.vectorStore = new LocalVectorStore(embeddings);
this.scope = new CoreAIScope(LLMModel.getGPT_3_5_TURBO(), LLMModel.getGPT_3_5_TURBO_FUNCTIONS(), client, vectorStore, embeddings, 3, "user", false, 0.4, 1, 20, 500);
}

public AIScope(CoreAIScope coreAIScope, ObjectMapper om) {
this.coreAIScope = coreAIScope;
this.om = om;
this.schemaGen = new JsonSchemaGenerator(om);
public AIScope(ObjectMapper om) {
this(om, new OpenAIConfig());
}

public static <T> T run(Function1<AIScope, T> block) {
OpenAIConfig config = new OpenAIConfig();
KtorOpenAIClient client = new KtorOpenAIClient(config);
try {
Embeddings embeddings = new OpenAIEmbeddings(config, client);
VectorStore vectorStore = new LocalVectorStore(embeddings);
CoreAIScope scope = new CoreAIScope(LLMModel.getGPT_3_5_TURBO(), LLMModel.getGPT_3_5_TURBO_FUNCTIONS(), client, vectorStore, embeddings, 3, "user", false, 0.4, 1, 20, 500);
return block.invoke(new AIScope(scope));
} finally {
client.close();
}
public AIScope(OpenAIConfig config) {
this(new ObjectMapper(), config);
}

public AIScope() {
this(new ObjectMapper(), new OpenAIConfig());
}

private <T> T undefined() {
throw new RuntimeException("Method is undefined");
}

public <A> A prompt(String prompt, Class<A> cls) {
return prompt(prompt, cls, coreAIScope.getMaxDeserializationAttempts(), coreAIScope.getDefaultSerializationModel(), coreAIScope.getUser(), coreAIScope.getEcho(), coreAIScope.getNumberOfPredictions(), coreAIScope.getTemperature(), coreAIScope.getDocsInContext(), coreAIScope.getMinResponseTokens());
return prompt(prompt, cls, scope.getMaxDeserializationAttempts(), scope.getDefaultSerializationModel(), scope.getUser(), scope.getEcho(), scope.getNumberOfPredictions(), scope.getTemperature(), scope.getDocsInContext(), scope.getMinResponseTokens());
}

public <A> A prompt(String prompt, Class<A> cls, Integer maxAttempts, LLMModel llmModel, String user, Boolean echo, Integer n, Double temperature, Integer bringFromContext, Integer minResponseTokens) {
ObjectMapper om = new ObjectMapper();

Function1<? super String, ? extends A> decoder = (json) -> {
try {
return om.readValue(json, cls);
Expand All @@ -91,23 +87,23 @@ public <A> A prompt(String prompt, Class<A> cls, Integer maxAttempts, LLMModel l
);

try {
return LoomAdapter.apply((continuation) -> coreAIScope.<A>promptWithSerializer(prompt, functions, decoder, maxAttempts, llmModel, user, echo, n, temperature, bringFromContext, minResponseTokens, continuation));
return LoomAdapter.apply((continuation) -> scope.<A>promptWithSerializer(prompt, functions, decoder, maxAttempts, llmModel, user, echo, n, temperature, bringFromContext, minResponseTokens, continuation));
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}

public List<String> promptMessage(String prompt, LLMModel llmModel, List<CFunction> functions, String user, Boolean echo, Integer n, Double temperature, Integer bringFromContext, Integer minResponseTokens) {
try {
return LoomAdapter.apply((continuation) -> coreAIScope.promptMessage(prompt, llmModel, functions, user, echo, n, temperature, bringFromContext, minResponseTokens, continuation));
return LoomAdapter.apply((continuation) -> scope.promptMessage(prompt, llmModel, functions, user, echo, n, temperature, bringFromContext, minResponseTokens, continuation));
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
}

public <T> T contextScope(List<String> docs) {
try {
return LoomAdapter.apply(continuation -> coreAIScope.contextScopeWithDocs(docs, undefined(), continuation));
return LoomAdapter.apply(continuation -> scope.contextScopeWithDocs(docs, undefined(), continuation));
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
Expand All @@ -131,7 +127,7 @@ public List<String> pdf(File file, TextSplitter splitter) {

public List<String> images(String prompt, String user, String size, Integer bringFromContext, Integer n) {
try {
ImagesGenerationResponse response = LoomAdapter.apply(continuation -> coreAIScope.images(prompt, user, n, size, bringFromContext, continuation));
ImagesGenerationResponse response = LoomAdapter.apply(continuation -> scope.images(prompt, user, n, size, bringFromContext, continuation));

return CollectionsKt.map(response.getData(), ImageGenerationUrl::getUrl);
} catch (InterruptedException e) {
Expand All @@ -142,6 +138,6 @@ public List<String> images(String prompt, String user, String size, Integer brin

@Override
public void close() {
coreAIScope.close();
client.close();
}
}

0 comments on commit 511d14a

Please sign in to comment.