From e880af260fe0467f135bc2723f2c39295df0b4c3 Mon Sep 17 00:00:00 2001
From: ron
Date: Mon, 4 Sep 2023 11:10:19 +0200
Subject: [PATCH 1/3] improve and fix local llm generation as flow

---
 .../com/xebia/functional/xef/llm/Chat.kt     | 20 ++--
 .../com/xebia/functional/gpt4all/GPT4All.kt  | 83 +++++++++----------
 2 files changed, 45 insertions(+), 58 deletions(-)

diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt
index d0cbd6083..8f9cefe6c 100644
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt
@@ -7,10 +7,7 @@ import com.xebia.functional.xef.conversation.Conversation
 import com.xebia.functional.xef.llm.models.chat.*
 import com.xebia.functional.xef.prompt.Prompt
 import com.xebia.functional.xef.prompt.templates.assistant
-import kotlinx.coroutines.flow.Flow
-import kotlinx.coroutines.flow.flow
-import kotlinx.coroutines.flow.onCompletion
-import kotlinx.coroutines.flow.onEach
+import kotlinx.coroutines.flow.*
 
 interface Chat : LLM {
   val modelType: ModelType
@@ -37,19 +34,14 @@ interface Chat : LLM {
         streamToStandardOut = true
       )
 
-    val buffer = StringBuilder()
     createChatCompletions(request)
-      .onEach {
-        it.choices.forEach { choice ->
-          val text = choice.delta?.content ?: ""
-          buffer.append(text)
-        }
-      }
-      .onCompletion {
-        val message = assistant(buffer.toString())
+      .map { it.choices.mapNotNull { it.delta?.content }.reduce(String::plus) }
+      .onEach { emit(it) }
+      .fold("", String::plus)
+      .also { finalText ->
+        val message = assistant(finalText)
         MemoryManagement.addMemoriesAfterStream(this@Chat, request, scope, listOf(message))
       }
-      .collect { emit(it.choices.mapNotNull { it.delta?.content }.joinToString("")) }
   }
 
   @AiDsl
diff --git a/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/GPT4All.kt b/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/GPT4All.kt
index ce1aeef2a..da3284fc6 100644
--- a/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/GPT4All.kt
+++ b/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/GPT4All.kt
@@ -16,21 +16,16 @@ import com.xebia.functional.xef.llm.models.text.CompletionResult
 import com.xebia.functional.xef.llm.models.usage.Usage
 import com.xebia.functional.xef.store.LocalVectorStore
 import com.xebia.functional.xef.store.VectorStore
-import kotlinx.coroutines.Dispatchers
 import kotlinx.coroutines.channels.Channel
 import kotlinx.coroutines.channels.Channel.Factory.UNLIMITED
-import kotlinx.coroutines.coroutineScope
-import kotlinx.coroutines.flow.Flow
-import kotlinx.coroutines.flow.consumeAsFlow
-import kotlinx.coroutines.flow.map
-import kotlinx.coroutines.flow.onCompletion
-import kotlinx.coroutines.launch
+import kotlinx.coroutines.flow.*
 import java.io.OutputStream
 import java.io.PrintStream
 import java.nio.charset.StandardCharsets
 import java.nio.file.Files
 import java.nio.file.Path
-import java.util.UUID
+import java.util.*
+import kotlin.concurrent.thread
 import kotlin.io.path.name
 
 
@@ -73,7 +68,6 @@ interface GPT4All : AutoCloseable, Chat, Completion {
 
   override suspend fun createCompletion(request: CompletionRequest): CompletionResult =
     with(request) {
-
       val config = LLModel.config()
         .withTopP(request.topP?.toFloat() ?: 0.4f)
         .withTemp(request.temperature?.toFloat() ?: 0f)
@@ -94,8 +88,8 @@ interface GPT4All : AutoCloseable, Chat, Completion {
     with(request) {
       val prompt: String = messages.buildPrompt()
       val config = LLModel.config()
-        .withTopP(request.topP.toFloat() ?: 0.4f)
-        .withTemp(request.temperature.toFloat() ?: 0f)
+        .withTopP(request.topP.toFloat())
+        .withTemp(request.temperature.toFloat())
         .withRepeatPenalty(request.frequencyPenalty.toFloat())
         .build()
       val response: String = generateCompletion(prompt, config, request.streamToStandardOut)
@@ -117,51 +111,52 @@ interface GPT4All : AutoCloseable, Chat, Completion {
    * @param request The ChatCompletionRequest containing the necessary information for creating completions.
    * @return A Flow of ChatCompletionChunk objects representing the generated chat completions.
    */
-  override suspend fun createChatCompletions(request: ChatCompletionRequest): Flow<ChatCompletionChunk> =
-    with(request) {
-      val prompt: String = messages.buildPrompt()
-      val config = LLModel.config()
-        .withTopP(request.topP.toFloat())
-        .withTemp(request.temperature.toFloat())
-        .withRepeatPenalty(request.frequencyPenalty.toFloat())
+  override suspend fun createChatCompletions(request: ChatCompletionRequest): Flow<ChatCompletionChunk> {
+    val prompt: String = request.messages.buildPrompt()
+    val config = with(request) {
+      LLModel.config()
+        .withTopP(topP.toFloat())
+        .withTemp(temperature.toFloat())
+        .withRepeatPenalty(frequencyPenalty.toFloat())
         .build()
+    }
 
-      val originalOut = System.out // Save the original standard output
+    val originalOut = System.out // Save the original standard output
 
-      return coroutineScope {
-        val channel = Channel<String>(capacity = UNLIMITED)
+    val channel = Channel<String>(capacity = UNLIMITED)
 
-        val outputStream = object : OutputStream() {
-          override fun write(b: Int) {
-            val c = b.toChar()
-            channel.trySend(c.toString())
-          }
-        }
+    val outputStream = object : OutputStream() {
+      override fun write(b: Int) {
+        val c = b.toChar()
+        channel.trySend(c.toString())
+      }
+    }
 
-        val printStream = PrintStream(outputStream, true, StandardCharsets.UTF_8)
+    val printStream = PrintStream(outputStream, true, StandardCharsets.UTF_8)
 
-        fun toChunk(text: String?): ChatCompletionChunk =
-          ChatCompletionChunk(
-            UUID.randomUUID().toString(),
-            System.currentTimeMillis().toInt(),
-            path.name,
-            listOf(ChatChunk(delta = ChatDelta(Role.ASSISTANT, text))),
-            Usage.ZERO,
-          )
+    fun toChunk(text: String?): ChatCompletionChunk = ChatCompletionChunk(
+      UUID.randomUUID().toString(),
+      System.currentTimeMillis().toInt(),
+      path.name,
+      listOf(ChatChunk(delta = ChatDelta(Role.ASSISTANT, text))),
+      Usage.ZERO,
+    )
 
-        val flow = channel.consumeAsFlow().map { toChunk(it) }
+    return channel
+      .consumeAsFlow()
+      .map(::toChunk)
+      .onStart {
+        System.setOut(printStream) // Set the standard output to the print stream
 
-        launch(Dispatchers.IO) {
-          System.setOut(printStream) // Set the standard output to the print stream
+        thread(isDaemon = true) { // generate in background and emit values to flow
           generateCompletion(prompt, config, request.streamToStandardOut)
           channel.close()
         }
-
-        flow.onCompletion {
-          System.setOut(originalOut) // Restore the original standard output
-        }
       }
-    }
+      .onCompletion {
+        System.setOut(originalOut) // Restore the original standard output
+      }
+  }
 
   override fun tokensFromMessages(messages: List<Message>): Int {
     return 0
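Note on PATCH 1/3: the core trick is to treat the model's stdout as the token
stream. System.out is swapped for a PrintStream that forwards every byte into
a Channel, a daemon thread runs the blocking generator, and the Channel is
consumed as a Flow. The following is a minimal, self-contained sketch of that
pattern; generateToStdout and stdoutAsFlow are illustrative names, not part of
the patch:

import java.io.OutputStream
import java.io.PrintStream
import java.nio.charset.StandardCharsets
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.runBlocking
import kotlin.concurrent.thread

// Stand-in for GPT4All's generateCompletion: a blocking generator that can
// only report progress by printing to System.out.
fun generateToStdout() {
  "hello world".forEach { print(it) }
}

// Redirect System.out into a Channel-backed Flow so the print-only generator
// can be consumed as a stream of tokens.
fun stdoutAsFlow(generate: () -> Unit): Flow<String> {
  val channel = Channel<String>(capacity = Channel.UNLIMITED)
  val originalOut = System.out
  val printStream = PrintStream(object : OutputStream() {
    override fun write(b: Int) {
      channel.trySend(b.toChar().toString()) // byte-per-byte, ASCII-safe only
    }
  }, true, StandardCharsets.UTF_8)
  return channel.consumeAsFlow()
    .onStart {
      System.setOut(printStream)
      thread(isDaemon = true) { // run the blocking generator off the caller's thread
        generate()
        channel.close() // closing ends the flow once the buffer is drained
      }
    }
    .onCompletion { System.setOut(originalOut) }
}

fun main() = runBlocking {
  val tokens = stdoutAsFlow(::generateToStdout).toList()
  println(tokens.joinToString("")) // prints "hello world"
}

While System.out is redirected, anything the collector itself prints would be
fed back into the channel, so the sketch collects first and prints only after
the flow completes; the patch operates under the same constraint while
generation is running.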
From 3a166fd0b201f5ef0b00d2f5075b001bcf41a070 Mon Sep 17 00:00:00 2001
From: ron
Date: Tue, 5 Sep 2023 12:18:00 +0200
Subject: [PATCH 2/3] change reduce to reduceNotNull to avoid exception

---
 core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt
index 8f9cefe6c..181706f7f 100644
--- a/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt
+++ b/core/src/commonMain/kotlin/com/xebia/functional/xef/llm/Chat.kt
@@ -35,7 +35,7 @@ interface Chat : LLM {
     )
 
     createChatCompletions(request)
-      .map { it.choices.mapNotNull { it.delta?.content }.reduce(String::plus) }
+      .mapNotNull { it.choices.mapNotNull { it.delta?.content }.reduceOrNull(String::plus) }
      .onEach { emit(it) }
       .fold("", String::plus)
       .also { finalText ->
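Note on PATCH 2/3: reduce on an empty collection throws
UnsupportedOperationException, and a streamed chunk may legitimately carry no
delta content, so the previous .map { ... reduce(String::plus) } stage could
crash mid-stream. reduceOrNull returns null instead, and mapNotNull then drops
the empty chunk. A minimal illustration, with an empty list standing in for a
chunk whose choices carry no content:

fun main() {
  val contents: List<String> = emptyList() // no delta content in this chunk

  // contents.reduce(String::plus) would throw UnsupportedOperationException

  val joined: String? = contents.reduceOrNull(String::plus)
  println(joined) // null, so mapNotNull simply skips the chunk
}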
From 9e2f8f2bef07cb0f9bdca3c2eaa4599e51f16208 Mon Sep 17 00:00:00 2001
From: ron
Date: Tue, 5 Sep 2023 15:15:35 +0200
Subject: [PATCH 3/3] addressing comment regarding thread instantiation

---
 .../com/xebia/functional/gpt4all/GPT4All.kt | 29 +++++++++----------
 1 file changed, 13 insertions(+), 16 deletions(-)

diff --git a/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/GPT4All.kt b/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/GPT4All.kt
index da3284fc6..8462c9126 100644
--- a/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/GPT4All.kt
+++ b/gpt4all-kotlin/src/jvmMain/kotlin/com/xebia/functional/gpt4all/GPT4All.kt
@@ -16,6 +16,7 @@ import com.xebia.functional.xef.llm.models.text.CompletionResult
 import com.xebia.functional.xef.llm.models.usage.Usage
 import com.xebia.functional.xef.store.LocalVectorStore
 import com.xebia.functional.xef.store.VectorStore
+import kotlinx.coroutines.Dispatchers
 import kotlinx.coroutines.channels.Channel
 import kotlinx.coroutines.channels.Channel.Factory.UNLIMITED
 import kotlinx.coroutines.flow.*
@@ -25,7 +26,6 @@ import java.nio.charset.StandardCharsets
 import java.nio.file.Files
 import java.nio.file.Path
 import java.util.*
-import kotlin.concurrent.thread
 import kotlin.io.path.name
 
 
@@ -121,10 +121,7 @@ interface GPT4All : AutoCloseable, Chat, Completion {
         .build()
     }
 
-    val originalOut = System.out // Save the original standard output
-
     val channel = Channel<String>(capacity = UNLIMITED)
-
     val outputStream = object : OutputStream() {
       override fun write(b: Int) {
         val c = b.toChar()
@@ -132,7 +129,7 @@ interface GPT4All : AutoCloseable, Chat, Completion {
       }
     }
 
-    val printStream = PrintStream(outputStream, true, StandardCharsets.UTF_8)
+    val originalOut = System.out // Save the original standard output
 
     fun toChunk(text: String?): ChatCompletionChunk = ChatCompletionChunk(
       UUID.randomUUID().toString(),
@@ -142,20 +139,20 @@ interface GPT4All : AutoCloseable, Chat, Completion {
       Usage.ZERO,
     )
 
-    return channel
-      .consumeAsFlow()
-      .map(::toChunk)
-      .onStart {
-        System.setOut(printStream) // Set the standard output to the print stream
-
-        thread(isDaemon = true) { // generate in background and emit values to flow
+    return merge(
+      emptyFlow<ChatCompletionChunk>()
+        .onStart {
+          val printStream = PrintStream(outputStream, true, StandardCharsets.UTF_8)
+          System.setOut(printStream) // Set the standard output to the print stream
           generateCompletion(prompt, config, request.streamToStandardOut)
           channel.close()
+          System.setOut(originalOut) // Restore the original standard output
         }
-      }
-      .onCompletion {
-        System.setOut(originalOut) // Restore the original standard output
-      }
+        .flowOn(Dispatchers.IO),
+      channel
+        .consumeAsFlow()
+        .map(::toChunk)
+    )
   }
 
   override fun tokensFromMessages(messages: List<Message>): Int {
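Note on PATCH 3/3: instead of spawning a raw daemon thread, the blocking
generator now runs in the onStart block of an empty flow that flowOn shifts to
Dispatchers.IO, and merge runs that producer branch concurrently with the
channel-consuming branch. A self-contained sketch of the same shape, with an
Int payload and a hypothetical producerAsFlow in place of the GPT4All
specifics:

import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.flow.*
import kotlinx.coroutines.runBlocking

// Run a blocking producer on Dispatchers.IO and surface its output as a Flow.
fun producerAsFlow(produce: (send: (Int) -> Unit) -> Unit): Flow<Int> {
  val channel = Channel<Int>(capacity = Channel.UNLIMITED)
  return merge(
    emptyFlow<Int>() // emits nothing; only hosts the producer in onStart
      .onStart {
        produce { channel.trySend(it) } // blocking work, shifted by flowOn below
        channel.close()
      }
      .flowOn(Dispatchers.IO),
    channel.consumeAsFlow()
  )
}

fun main() = runBlocking {
  val squares = producerAsFlow { send -> (1..5).forEach { send(it * it) } }
  println(squares.toList()) // [1, 4, 9, 16, 25]
}

Because onStart sits upstream of flowOn, the blocking call never lands on the
collector's dispatcher, and merge completes only when both branches finish, so
closing the channel after the producer returns terminates the stream cleanly.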