From dc0aa2f37df1bb37701b8e8b2463d7af7ab6096e Mon Sep 17 00:00:00 2001
From: Daniel Ivan Gelvez Leon
Date: Mon, 24 Jul 2023 12:00:16 -0500
Subject: [PATCH] Java example gpt4all (#253)

* Java Examples web search

* Add Search to AIScope

* Add Search to AIScope

* Refactor

* Refactor

* Add Chat

* Refactor

* Start java example gpt4all

* Add flow collector

* Add port to test

* Finish java example gpt4all

* Refactor

* Add Publisher

* Refactor

* Refactor

* Refactor

* Refactor

* Final Refactor

* Final Refactor conversationId
---
 .../xebia/functional/xef/auto/CoreAIScope.kt  | 11 +++
 .../com/xebia/functional/xef/llm/Chat.kt      |  8 +-
 examples/java/build.gradle.kts                |  1 +
 .../xef/java/auto/gpt4all/Chat.java           | 82 +++++++++++++++++++
 gradle/libs.versions.toml                     |  2 +
 java/build.gradle.kts                         |  1 +
 .../functional/xef/java/auto/AIScope.java     | 24 ++++--
 .../xef/java/auto/ExecutionContext.java       | 24 ++++--
 8 files changed, 134 insertions(+), 19 deletions(-)
 create mode 100644 examples/java/src/main/java/com/xebia/functional/xef/java/auto/gpt4all/Chat.java

diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/CoreAIScope.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/CoreAIScope.kt
index 5a1b5a381..29dbbe306 100644
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/CoreAIScope.kt
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/auto/CoreAIScope.kt
@@ -12,6 +12,7 @@ import com.xebia.functional.xef.vectorstores.CombinedVectorStore
 import com.xebia.functional.xef.vectorstores.ConversationId
 import com.xebia.functional.xef.vectorstores.LocalVectorStore
 import com.xebia.functional.xef.vectorstores.VectorStore
+import kotlinx.coroutines.flow.Flow
 import kotlin.jvm.JvmName
 import kotlin.jvm.JvmOverloads
 import kotlinx.uuid.UUID
@@ -130,6 +131,16 @@ constructor(
   ): List<String> =
     promptMessages(Prompt(question), context, conversationId, functions, promptConfiguration)
 
+  @AiDsl
+  fun Chat.promptStreaming(
+    question: String,
+    context: VectorStore,
+    conversationId: ConversationId?,
+    functions: List<CFunction>,
+    promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS
+  ): Flow<String> =
+    promptStreaming(Prompt(question), context, conversationId, functions, promptConfiguration)
+
   /**
    * Run a [prompt] describes the images you want to generate within the context of [CoreAIScope].
    * Returns a [ImagesGenerationResponse] containing time and urls with images generated.
diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt
index d24b9d253..1fe4b2d4f 100644
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt
@@ -27,7 +27,7 @@ interface Chat : LLM {
   fun tokensFromMessages(messages: List<Message>): Int
 
   @AiDsl
-  suspend fun promptStreaming(
+  fun promptStreaming(
     question: String,
     context: VectorStore,
     conversationId: ConversationId? = null,
@@ -37,13 +37,13 @@ interface Chat : LLM {
     promptStreaming(Prompt(question), context, conversationId, functions, promptConfiguration)
 
   @AiDsl
-  suspend fun promptStreaming(
+  fun promptStreaming(
     prompt: Prompt,
     context: VectorStore,
     conversationId: ConversationId? = null,
     functions: List<CFunction> = emptyList(),
     promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS
-  ): Flow<String> {
+  ): Flow<String> = flow {
 
     val memories: List<Memory> = memories(conversationId, context, promptConfiguration)
 
@@ -79,7 +79,6 @@ interface Chat : LLM {
         streamToStandardOut = true
       )
 
-    return flow {
       val buffer = StringBuilder()
       createChatCompletions(request)
         .onEach {
@@ -90,7 +89,6 @@ interface Chat : LLM {
         }
         .onCompletion { addMemoriesAfterStream(request, conversationId, buffer, context) }
         .collect { emit(it.choices.mapNotNull { it.delta?.content }.joinToString("")) }
-    }
   }
 
   private suspend fun addMemoriesAfterStream(
diff --git a/examples/java/build.gradle.kts b/examples/java/build.gradle.kts
index 8a9ae2c7f..be1962a84 100644
--- a/examples/java/build.gradle.kts
+++ b/examples/java/build.gradle.kts
@@ -7,6 +7,7 @@ plugins {
 
 dependencies {
   implementation(projects.xefJava)
+  implementation(projects.xefGpt4all)
 }
 
 tasks.withType().configureEach {
diff --git a/examples/java/src/main/java/com/xebia/functional/xef/java/auto/gpt4all/Chat.java b/examples/java/src/main/java/com/xebia/functional/xef/java/auto/gpt4all/Chat.java
new file mode 100644
index 000000000..06a58291d
--- /dev/null
+++ b/examples/java/src/main/java/com/xebia/functional/xef/java/auto/gpt4all/Chat.java
@@ -0,0 +1,82 @@
+package com.xebia.functional.xef.java.auto.gpt4all;
+
+import com.xebia.functional.gpt4all.GPT4All;
+import com.xebia.functional.gpt4all.Gpt4AllModel;
+import com.xebia.functional.xef.auto.PromptConfiguration;
+import com.xebia.functional.xef.java.auto.AIScope;
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.nio.file.Path;
+import java.util.Objects;
+import java.util.concurrent.ExecutionException;
+import org.reactivestreams.Publisher;
+import org.reactivestreams.Subscriber;
+import org.reactivestreams.Subscription;
+
+public class Chat {
+    public static void main(String[] args) throws ExecutionException, InterruptedException, IOException {
+        var userDir = System.getProperty("user.dir");
+        var path = userDir + "/models/gpt4all/ggml-replit-code-v1-3b.bin";
+
+        var supportedModels = Gpt4AllModel.Companion.getSupportedModels();
+
+        supportedModels.forEach(it -> {
+            var url = (Objects.nonNull(it.getUrl())) ? " - " + it.getUrl() : "";
+            System.out.println("🤖 " + it.getName() + url);
+        });
+
+        var url = "https://huggingface.co/nomic-ai/ggml-replit-code-v1-3b/resolve/main/ggml-replit-code-v1-3b.bin";
+        var modelPath = Path.of(path);
+        var gpt4all = GPT4All.Companion.invoke(url, modelPath);
+
+        System.out.println("🤖 GPT4All loaded: " + gpt4all);
+
+        /**
+         * Uses internally [HuggingFaceLocalEmbeddings] default of "sentence-transformers", "msmarco-distilbert-dot-v5"
+         * to provide embeddings for docs in contextScope.
+         */
+        try (AIScope scope = new AIScope();
+             BufferedReader br = new BufferedReader(new InputStreamReader(System.in))) {
+
+            System.out.println("🤖 Context loaded: " + scope.getExec().getContext());
+
+            System.out.println("\n🤖 Enter your question: ");
+
+            while(true){
+                String line = br.readLine();
+                if (line.equals("exit")) break;
+
+                var promptConfiguration = new PromptConfiguration.Companion.Builder().docsInContext(2).streamToStandardOut(true).build();
+                Publisher<String> answer = scope.promptStreaming(gpt4all, line, promptConfiguration);
+
+                answer.subscribe(new Subscriber<String>() {
+                    StringBuilder answer = new StringBuilder();
+
+                    @Override
+                    public void onSubscribe(Subscription s) {
+                        System.out.print("\n🤖 --> " + s);
+                        s.request(Long.MAX_VALUE);
+                    }
+
+                    @Override
+                    public void onNext(String s) {
+                        answer.append(s);
+                    }
+
+                    @Override
+                    public void onError(Throwable t) {
+                        System.out.println(t);
+                    }
+
+                    @Override
+                    public void onComplete() {
+                        System.out.println("\n🤖 --> " + answer.toString());
+                        System.out.println("\n🤖 --> Done");
+                        System.out.println("\n🤖 Enter your question: ");
+                    }
+                });
+            }
+        }
+    }
+}
diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml
index b9c6ce91b..16a5ae689 100644
--- a/gradle/libs.versions.toml
+++ b/gradle/libs.versions.toml
@@ -18,6 +18,7 @@ hikari = "5.0.1"
 dokka = "1.8.20"
 logback = "1.4.8"
 kotlinx-coroutines = "1.7.2"
+kotlinx-coroutines-reactive = "1.7.2"
 scalaMultiversion = "2.0.4"
 circe = "0.14.5"
 catsEffect = "3.6-0142603"
@@ -45,6 +46,7 @@ arrow-fx-coroutines = { module = "io.arrow-kt:arrow-fx-coroutines", version.ref
 open-ai = { module = "com.theokanning.openai-gpt3-java:service", version.ref = "openai" }
 kotlinx-serialization-json = { module = "org.jetbrains.kotlinx:kotlinx-serialization-json", version.ref = "kotlinx-json" }
 kotlinx-coroutines = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-core", version.ref="kotlinx-coroutines" }
+kotlinx-coroutines-reactive = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-reactive", version.ref="kotlinx-coroutines-reactive" }
 ktor-utils = { module = "io.ktor:ktor-utils", version.ref = "ktor" }
 ktor-client ={ module = "io.ktor:ktor-client-core", version.ref = "ktor" }
 ktor-client-cio = { module = "io.ktor:ktor-client-cio", version.ref = "ktor" }
diff --git a/java/build.gradle.kts b/java/build.gradle.kts
index d4303bcf2..4d9a61920 100644
--- a/java/build.gradle.kts
+++ b/java/build.gradle.kts
@@ -19,6 +19,7 @@ dependencies {
   api(libs.jackson.schema)
   api(libs.jackson.schema.jakarta)
   api(libs.jakarta.validation)
+  api(libs.kotlinx.coroutines.reactive)
 }
 
 java {
diff --git a/java/src/main/java/com/xebia/functional/xef/java/auto/AIScope.java b/java/src/main/java/com/xebia/functional/xef/java/auto/AIScope.java
index 8aebed6a1..3e6ed4524 100644
--- a/java/src/main/java/com/xebia/functional/xef/java/auto/AIScope.java
+++ b/java/src/main/java/com/xebia/functional/xef/java/auto/AIScope.java
@@ -2,7 +2,11 @@
 
 import com.fasterxml.jackson.core.JsonProcessingException;
 import com.fasterxml.jackson.databind.ObjectMapper;
-import com.github.victools.jsonschema.generator.*;
+import com.github.victools.jsonschema.generator.OptionPreset;
+import com.github.victools.jsonschema.generator.SchemaGenerator;
+import com.github.victools.jsonschema.generator.SchemaGeneratorConfig;
+import com.github.victools.jsonschema.generator.SchemaGeneratorConfigBuilder;
+import com.github.victools.jsonschema.generator.SchemaVersion;
 import com.github.victools.jsonschema.module.jakarta.validation.JakartaValidationModule;
 import com.github.victools.jsonschema.module.jakarta.validation.JakartaValidationOption;
 import com.xebia.functional.xef.agents.Search;
@@ -20,14 +24,15 @@
 import com.xebia.functional.xef.sql.SQL;
 import com.xebia.functional.xef.textsplitters.TextSplitter;
 import com.xebia.functional.xef.vectorstores.VectorStore;
-import kotlin.collections.CollectionsKt;
-import kotlin.jvm.functions.Function1;
-import kotlinx.coroutines.future.FutureKt;
-
 import java.io.File;
 import java.util.Collections;
 import java.util.List;
 import java.util.concurrent.CompletableFuture;
+import kotlin.collections.CollectionsKt;
+import kotlin.jvm.functions.Function1;
+import kotlinx.coroutines.future.FutureKt;
+import kotlinx.coroutines.reactive.ReactiveFlowKt;
+import org.reactivestreams.Publisher;
 
 public class AIScope implements AutoCloseable {
     private final CoreAIScope scope;
@@ -49,6 +54,10 @@ public AIScope(ObjectMapper om, ExecutionContext executionContext) {
         this.scope = executionContext.getCoreScope();
     }
 
+    public ExecutionContext getExec() {
+        return exec;
+    }
+
     public AIScope(ExecutionContext executionContext) {
         this(new ObjectMapper(), executionContext);
     }
@@ -99,6 +108,10 @@ public CompletableFuture<List<String>> promptMessages(Chat llmModel, String prom
         return exec.future(continuation -> scope.promptMessages(llmModel, prompt, functions, promptConfiguration, continuation));
     }
 
+    public Publisher<String> promptStreaming(Chat gpt4all, String line, PromptConfiguration promptConfiguration) {
+        return ReactiveFlowKt.asPublisher(scope.promptStreaming(gpt4all, line, exec.getContext(), scope.getConversationId(), Collections.emptyList(), promptConfiguration));
+    }
+
     public <A> CompletableFuture<A> contextScope(Function1<Embeddings, VectorStore> store, Function1<AIScope, CompletableFuture<A>> f) {
         return exec.future(continuation -> scope.contextScope(store.invoke(scope.getEmbeddings()), (coreAIScope, continuation1) -> {
             AIScope nestedScope = new AIScope(coreAIScope, AIScope.this);
         }, continuation));
     }
 
-
     public <A> CompletableFuture<A> contextScope(VectorStore store, Function1<AIScope, CompletableFuture<A>> f) {
         return exec.future(continuation -> scope.contextScope(store, (coreAIScope, continuation1) -> {
             AIScope nestedScope = new AIScope(coreAIScope, AIScope.this);
diff --git a/java/src/main/java/com/xebia/functional/xef/java/auto/ExecutionContext.java b/java/src/main/java/com/xebia/functional/xef/java/auto/ExecutionContext.java
index d5e4af80c..fad9ee4fa 100644
--- a/java/src/main/java/com/xebia/functional/xef/java/auto/ExecutionContext.java
+++ b/java/src/main/java/com/xebia/functional/xef/java/auto/ExecutionContext.java
@@ -6,23 +6,27 @@
 import com.xebia.functional.xef.embeddings.Embeddings;
 import com.xebia.functional.xef.vectorstores.LocalVectorStore;
 import com.xebia.functional.xef.vectorstores.VectorStore;
-import kotlin.coroutines.Continuation;
-import kotlin.jvm.functions.Function1;
-import kotlinx.coroutines.*;
-import kotlinx.coroutines.future.FutureKt;
-import org.jetbrains.annotations.NotNull;
-
 import java.util.concurrent.CompletableFuture;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.ThreadFactory;
 import java.util.concurrent.atomic.AtomicInteger;
+import kotlin.coroutines.Continuation;
+import kotlin.jvm.functions.Function1;
+import kotlinx.coroutines.CoroutineScope;
+import kotlinx.coroutines.CoroutineScopeKt;
+import kotlinx.coroutines.CoroutineStart;
+import kotlinx.coroutines.ExecutorsKt;
+import kotlinx.coroutines.JobKt;
+import kotlinx.coroutines.future.FutureKt;
+import org.jetbrains.annotations.NotNull;
 
 public class ExecutionContext implements AutoCloseable {
     private final ExecutorService executorService;
     private final CoroutineScope coroutineScope;
     private final CoreAIScope scope;
+    private final VectorStore context;
 
     public ExecutionContext(){
         this(Executors.newCachedThreadPool(new ExecutionContext.AIScopeThreadFactory()), new OpenAIEmbeddings(OpenAI.DEFAULT_EMBEDDING));
     }
@@ -31,8 +35,8 @@ public ExecutionContext(){
     public ExecutionContext(ExecutorService executorService, Embeddings embeddings) {
         this.executorService = executorService;
         this.coroutineScope = () -> ExecutorsKt.from(executorService).plus(JobKt.Job(null));
-        VectorStore vectorStore = new LocalVectorStore(embeddings);
-        this.scope = new CoreAIScope(embeddings, vectorStore);
+        context = new LocalVectorStore(embeddings);
+        this.scope = new CoreAIScope(embeddings, context);
     }
 
     protected <A> CompletableFuture<A> future(Function1<? super Continuation<? super A>, ? extends Object> block) {
@@ -44,6 +48,10 @@ protected <A> CompletableFuture<A> future(Function1
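
Note on consuming the new streaming API from Kotlin: as the Chat.kt hunks above show, promptStreaming is no longer suspend and now returns a cold Flow<String> built with the flow builder, so callers collect partial completions instead of awaiting a single result. A minimal sketch, not part of the patch: the helper name streamAnswer and the caller-supplied scope, gpt4all, vectorStore and conversationId values are illustrative assumptions; only the promptStreaming signature comes from this change.

import com.xebia.functional.gpt4all.GPT4All
import com.xebia.functional.xef.auto.CoreAIScope
import com.xebia.functional.xef.auto.PromptConfiguration
import com.xebia.functional.xef.vectorstores.ConversationId
import com.xebia.functional.xef.vectorstores.VectorStore
import kotlinx.coroutines.flow.collect

// Collects the Flow<String> returned by the new non-suspend promptStreaming,
// printing each partial completion as it arrives. The CoreAIScope receiver is
// brought into scope with `with`, since promptStreaming is a member extension
// on Chat (which GPT4All implements, as the Java example above relies on).
suspend fun streamAnswer(
  scope: CoreAIScope,
  gpt4all: GPT4All,
  vectorStore: VectorStore,
  conversationId: ConversationId?,
  question: String
) = with(scope) {
  gpt4all
    .promptStreaming(question, vectorStore, conversationId, emptyList(), PromptConfiguration.DEFAULTS)
    .collect { token -> print(token) } // tokens stream in as they are generated
}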