Skip to content

Commit

Permalink
Java example gpt4all (#253)
Browse files Browse the repository at this point in the history
* Java Examples web search

* Add Search to AIScope

* Add Search to AIScope

* Refactor

* Refactor

* Add Chat

* Refactor

* Start java example gpt4all

* Add flow collector

* Add port to test

* Finish java example gpt4all

* Refactor

* Add Publisher

* Refactor

* Refactor

* Refactor

* Refactor

* Final Refactor

* Final Refactor conversationId
  • Loading branch information
Zevleg authored Jul 24, 2023
1 parent e98a7d4 commit dc0aa2f
Show file tree
Hide file tree
Showing 8 changed files with 134 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import com.xebia.functional.xef.vectorstores.CombinedVectorStore
import com.xebia.functional.xef.vectorstores.ConversationId
import com.xebia.functional.xef.vectorstores.LocalVectorStore
import com.xebia.functional.xef.vectorstores.VectorStore
import kotlinx.coroutines.flow.Flow
import kotlin.jvm.JvmName
import kotlin.jvm.JvmOverloads
import kotlinx.uuid.UUID
Expand Down Expand Up @@ -130,6 +131,16 @@ constructor(
): List<String> =
promptMessages(Prompt(question), context, conversationId, functions, promptConfiguration)

/**
 * Streams the model's answer to [question] as a [Flow] of text chunks.
 *
 * Convenience overload that wraps [question] in a [Prompt] and delegates to the
 * [Prompt]-based `promptStreaming` with the same [context], [conversationId],
 * [functions] and [promptConfiguration].
 */
@AiDsl
fun Chat.promptStreaming(
question: String,
context: VectorStore,
conversationId: ConversationId?,
functions: List<CFunction>,
promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS
): Flow<String> =
promptStreaming(Prompt(question), context, conversationId, functions, promptConfiguration)

/**
* Run a [prompt] describes the images you want to generate within the context of [CoreAIScope].
* Returns a [ImagesGenerationResponse] containing time and urls with images generated.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ interface Chat : LLM {
fun tokensFromMessages(messages: List<Message>): Int

@AiDsl
suspend fun promptStreaming(
fun promptStreaming(
question: String,
context: VectorStore,
conversationId: ConversationId? = null,
Expand All @@ -37,13 +37,13 @@ interface Chat : LLM {
promptStreaming(Prompt(question), context, conversationId, functions, promptConfiguration)

@AiDsl
suspend fun promptStreaming(
fun promptStreaming(
prompt: Prompt,
context: VectorStore,
conversationId: ConversationId? = null,
functions: List<CFunction> = emptyList(),
promptConfiguration: PromptConfiguration = PromptConfiguration.DEFAULTS
): Flow<String> {
): Flow<String> = flow {

val memories: List<Memory> = memories(conversationId, context, promptConfiguration)

Expand Down Expand Up @@ -79,7 +79,6 @@ interface Chat : LLM {
streamToStandardOut = true
)

return flow {
val buffer = StringBuilder()
createChatCompletions(request)
.onEach {
Expand All @@ -90,7 +89,6 @@ interface Chat : LLM {
}
.onCompletion { addMemoriesAfterStream(request, conversationId, buffer, context) }
.collect { emit(it.choices.mapNotNull { it.delta?.content }.joinToString("")) }
}
}

private suspend fun addMemoriesAfterStream(
Expand Down
1 change: 1 addition & 0 deletions examples/java/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ plugins {

dependencies {
implementation(projects.xefJava)
implementation(projects.xefGpt4all)
}

tasks.withType<Test>().configureEach {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package com.xebia.functional.xef.java.auto.gpt4all;

import com.xebia.functional.gpt4all.GPT4All;
import com.xebia.functional.gpt4all.Gpt4AllModel;
import com.xebia.functional.xef.auto.PromptConfiguration;
import com.xebia.functional.xef.java.auto.AIScope;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.file.Path;
import java.util.Objects;
import java.util.concurrent.ExecutionException;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscriber;
import org.reactivestreams.Subscription;

/**
 * Interactive console example: streams answers from a local GPT4All model.
 *
 * <p>Lists the models the GPT4All bindings know about, loads (or downloads) the
 * replit-code model, then reads questions from stdin in a loop and subscribes to
 * the streamed answer via a Reactive Streams {@link Publisher}. Type {@code exit}
 * (or send EOF, e.g. Ctrl-D) to quit.
 */
public class Chat {
    public static void main(String[] args) throws ExecutionException, InterruptedException, IOException {
        var userDir = System.getProperty("user.dir");
        var path = userDir + "/models/gpt4all/ggml-replit-code-v1-3b.bin";

        // Print every model the bindings support, with its download URL when one exists.
        var supportedModels = Gpt4AllModel.Companion.getSupportedModels();
        supportedModels.forEach(it -> {
            var url = (Objects.nonNull(it.getUrl())) ? " - " + it.getUrl() : "";
            System.out.println("🤖 " + it.getName() + url);
        });

        var url = "https://huggingface.co/nomic-ai/ggml-replit-code-v1-3b/resolve/main/ggml-replit-code-v1-3b.bin";
        var modelPath = Path.of(path);
        // presumably downloads the model to modelPath on first run — TODO confirm GPT4All.invoke semantics
        var gpt4all = GPT4All.Companion.invoke(url, modelPath);

        System.out.println("🤖 GPT4All loaded: " + gpt4all);
        /**
         * Uses internally [HuggingFaceLocalEmbeddings] default of "sentence-transformers", "msmarco-distilbert-dot-v5"
         * to provide embeddings for docs in contextScope.
         */

        try (AIScope scope = new AIScope();
             BufferedReader br = new BufferedReader(new InputStreamReader(System.in))) {

            System.out.println("🤖 Context loaded: " + scope.getExec().getContext());

            System.out.println("\n🤖 Enter your question: ");

            while (true) {
                String line = br.readLine();
                // readLine() returns null at EOF (e.g. Ctrl-D or piped input); the
                // original called line.equals("exit") first and NPE'd in that case.
                if (line == null || line.equals("exit")) break;

                var promptConfiguration = new PromptConfiguration.Companion.Builder()
                    .docsInContext(2)
                    .streamToStandardOut(true)
                    .build();
                Publisher<String> answer = scope.promptStreaming(gpt4all, line, promptConfiguration);

                answer.subscribe(new Subscriber<String>() {
                    // Accumulates streamed chunks; renamed from "answer", which
                    // shadowed the Publisher variable above.
                    private final StringBuilder buffer = new StringBuilder();

                    @Override
                    public void onSubscribe(Subscription s) {
                        System.out.print("\n🤖 --> " + s);
                        // Unbounded demand: consume the whole stream as it arrives.
                        s.request(Long.MAX_VALUE);
                    }

                    @Override
                    public void onNext(String s) {
                        buffer.append(s);
                    }

                    @Override
                    public void onError(Throwable t) {
                        // Errors belong on stderr, not mixed into the streamed answer on stdout.
                        System.err.println(t);
                    }

                    @Override
                    public void onComplete() {
                        System.out.println("\n🤖 --> " + buffer);
                        System.out.println("\n🤖 --> Done");
                        System.out.println("\n🤖 Enter your question: ");
                    }
                });
            }
        }
    }
}
2 changes: 2 additions & 0 deletions gradle/libs.versions.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ hikari = "5.0.1"
dokka = "1.8.20"
logback = "1.4.8"
kotlinx-coroutines = "1.7.2"
kotlinx-coroutines-reactive = "1.7.2"
scalaMultiversion = "2.0.4"
circe = "0.14.5"
catsEffect = "3.6-0142603"
Expand Down Expand Up @@ -45,6 +46,7 @@ arrow-fx-coroutines = { module = "io.arrow-kt:arrow-fx-coroutines", version.ref
open-ai = { module = "com.theokanning.openai-gpt3-java:service", version.ref = "openai" }
kotlinx-serialization-json = { module = "org.jetbrains.kotlinx:kotlinx-serialization-json", version.ref = "kotlinx-json" }
kotlinx-coroutines = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-core", version.ref="kotlinx-coroutines" }
kotlinx-coroutines-reactive = { module = "org.jetbrains.kotlinx:kotlinx-coroutines-reactive", version.ref="kotlinx-coroutines-reactive" }
ktor-utils = { module = "io.ktor:ktor-utils", version.ref = "ktor" }
ktor-client ={ module = "io.ktor:ktor-client-core", version.ref = "ktor" }
ktor-client-cio = { module = "io.ktor:ktor-client-cio", version.ref = "ktor" }
Expand Down
1 change: 1 addition & 0 deletions java/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ dependencies {
api(libs.jackson.schema)
api(libs.jackson.schema.jakarta)
api(libs.jakarta.validation)
api(libs.kotlinx.coroutines.reactive)
}

java {
Expand Down
24 changes: 18 additions & 6 deletions java/src/main/java/com/xebia/functional/xef/java/auto/AIScope.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.github.victools.jsonschema.generator.*;
import com.github.victools.jsonschema.generator.OptionPreset;
import com.github.victools.jsonschema.generator.SchemaGenerator;
import com.github.victools.jsonschema.generator.SchemaGeneratorConfig;
import com.github.victools.jsonschema.generator.SchemaGeneratorConfigBuilder;
import com.github.victools.jsonschema.generator.SchemaVersion;
import com.github.victools.jsonschema.module.jakarta.validation.JakartaValidationModule;
import com.github.victools.jsonschema.module.jakarta.validation.JakartaValidationOption;
import com.xebia.functional.xef.agents.Search;
Expand All @@ -20,14 +24,15 @@
import com.xebia.functional.xef.sql.SQL;
import com.xebia.functional.xef.textsplitters.TextSplitter;
import com.xebia.functional.xef.vectorstores.VectorStore;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.functions.Function1;
import kotlinx.coroutines.future.FutureKt;

import java.io.File;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import kotlin.collections.CollectionsKt;
import kotlin.jvm.functions.Function1;
import kotlinx.coroutines.future.FutureKt;
import kotlinx.coroutines.reactive.ReactiveFlowKt;
import org.reactivestreams.Publisher;

public class AIScope implements AutoCloseable {
private final CoreAIScope scope;
Expand All @@ -49,6 +54,10 @@ public AIScope(ObjectMapper om, ExecutionContext executionContext) {
this.scope = executionContext.getCoreScope();
}

/**
 * Returns the {@code ExecutionContext} backing this scope, giving callers access
 * to its executor and vector-store context.
 */
public ExecutionContext getExec() {
return exec;
}

/**
 * Creates a scope over {@code executionContext} with a default Jackson
 * {@link ObjectMapper} for JSON-schema generation.
 */
public AIScope(ExecutionContext executionContext) {
this(new ObjectMapper(), executionContext);
}
Expand Down Expand Up @@ -99,14 +108,17 @@ public CompletableFuture<List<String>> promptMessages(Chat llmModel, String prom
return exec.future(continuation -> scope.promptMessages(llmModel, prompt, functions, promptConfiguration, continuation));
}

/**
 * Streams the model's answer to {@code line} as a Reactive Streams
 * {@link Publisher} of text chunks.
 *
 * <p>Bridges the Kotlin {@code Flow<String>} returned by the core scope to a
 * {@code Publisher} via {@code kotlinx-coroutines-reactive}, using this scope's
 * vector-store context and conversation id, with no callable functions.
 */
public Publisher<String> promptStreaming(Chat gpt4all, String line, PromptConfiguration promptConfiguration) {
    var stream = scope.promptStreaming(
        gpt4all,
        line,
        exec.getContext(),
        scope.getConversationId(),
        Collections.emptyList(),
        promptConfiguration);
    return ReactiveFlowKt.asPublisher(stream);
}

/**
 * Runs {@code f} in a nested scope backed by the vector store produced by
 * {@code store} from this scope's embeddings, returning the result as a
 * {@link CompletableFuture}.
 */
public <A> CompletableFuture<A> contextScope(Function1<Embeddings, VectorStore> store, Function1<AIScope, CompletableFuture<A>> f) {
return exec.future(continuation -> scope.contextScope(store.invoke(scope.getEmbeddings()), (coreAIScope, continuation1) -> {
AIScope nestedScope = new AIScope(coreAIScope, AIScope.this);
// NOTE(review): awaits with the outer `continuation` while the inner lambda's
// `continuation1` is unused — confirm this is intentional and does not resume
// the outer continuation twice.
return FutureKt.await(f.invoke(nestedScope), continuation);
}, continuation));
}


public <A> CompletableFuture<A> contextScope(VectorStore store, Function1<AIScope, CompletableFuture<A>> f) {
return exec.future(continuation -> scope.contextScope(store, (coreAIScope, continuation1) -> {
AIScope nestedScope = new AIScope(coreAIScope, AIScope.this);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,27 @@
import com.xebia.functional.xef.embeddings.Embeddings;
import com.xebia.functional.xef.vectorstores.LocalVectorStore;
import com.xebia.functional.xef.vectorstores.VectorStore;
import kotlin.coroutines.Continuation;
import kotlin.jvm.functions.Function1;
import kotlinx.coroutines.*;
import kotlinx.coroutines.future.FutureKt;
import org.jetbrains.annotations.NotNull;

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicInteger;
import kotlin.coroutines.Continuation;
import kotlin.jvm.functions.Function1;
import kotlinx.coroutines.CoroutineScope;
import kotlinx.coroutines.CoroutineScopeKt;
import kotlinx.coroutines.CoroutineStart;
import kotlinx.coroutines.ExecutorsKt;
import kotlinx.coroutines.JobKt;
import kotlinx.coroutines.future.FutureKt;
import org.jetbrains.annotations.NotNull;

public class ExecutionContext implements AutoCloseable {

private final ExecutorService executorService;
private final CoroutineScope coroutineScope;
private final CoreAIScope scope;
private final VectorStore context;

public ExecutionContext(){
this(Executors.newCachedThreadPool(new ExecutionContext.AIScopeThreadFactory()), new OpenAIEmbeddings(OpenAI.DEFAULT_EMBEDDING));
Expand All @@ -31,8 +35,8 @@ public ExecutionContext(){
public ExecutionContext(ExecutorService executorService, Embeddings embeddings) {
this.executorService = executorService;
this.coroutineScope = () -> ExecutorsKt.from(executorService).plus(JobKt.Job(null));
VectorStore vectorStore = new LocalVectorStore(embeddings);
this.scope = new CoreAIScope(embeddings, vectorStore);
context = new LocalVectorStore(embeddings);
this.scope = new CoreAIScope(embeddings, context);
}

protected <A> CompletableFuture<A> future(Function1<? super Continuation<? super A>, ? extends Object> block) {
Expand All @@ -44,6 +48,10 @@ protected <A> CompletableFuture<A> future(Function1<? super Continuation<? super
);
}

/**
 * Returns the {@link VectorStore} created from this context's embeddings and
 * shared with its {@code CoreAIScope}.
 */
public VectorStore getContext() {
return context;
}

@Override
public void close() {
CoroutineScopeKt.cancel(coroutineScope, null);
Expand Down

0 comments on commit dc0aa2f

Please sign in to comment.