Skip to content

Commit

Permalink
Java: Add GeminiTextGenerationService and usage data to chat completi…
Browse files Browse the repository at this point in the history
…on (#6583)

### Motivation and Context

<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users by providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->

### Description

Fixes #6580

<!-- Describe your changes, the overall approach, the underlying design.
These notes will help reviewers understand how your code works. Thanks! -->

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [ ] The code builds clean without any errors or warnings
- [ ] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [ ] All unit tests pass, and I have added new tests where possible
- [ ] I didn't break anyone 😄
  • Loading branch information
milderhc committed Jun 11, 2024
1 parent 7579ccc commit 9f409dc
Show file tree
Hide file tree
Showing 10 changed files with 191 additions and 207 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
import com.microsoft.semantickernel.aiservices.google.implementation.MonoConverter;
import com.microsoft.semantickernel.contextvariables.ContextVariableTypes;
import com.microsoft.semantickernel.exceptions.AIException;
import com.microsoft.semantickernel.orchestration.FunctionResult;
import com.microsoft.semantickernel.exceptions.SKCheckedException;
import com.microsoft.semantickernel.exceptions.SKException;
import com.microsoft.semantickernel.orchestration.FunctionResultMetadata;
import com.microsoft.semantickernel.orchestration.InvocationContext;
import com.microsoft.semantickernel.orchestration.InvocationReturnMode;
import com.microsoft.semantickernel.orchestration.PromptExecutionSettings;
Expand All @@ -40,8 +42,10 @@

import javax.annotation.Nullable;
import java.io.IOException;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import java.util.stream.Collectors;

public class GeminiChatCompletion extends GeminiService implements ChatCompletionService {
Expand Down Expand Up @@ -87,11 +91,13 @@ public Mono<List<ChatMessageContent<?>>> getChatMessageContentsAsync(ChatHistory
private Mono<List<ChatMessageContent<?>>> internalChatMessageContentsAsync(
ChatHistory fullHistory, ChatHistory newHistory, @Nullable Kernel kernel,
@Nullable InvocationContext invocationContext, int invocationAttempts) {

List<Content> contents = getContents(fullHistory);
GenerativeModel model = getGenerativeModel(kernel, invocationContext);

try {
GenerativeModel model = getGenerativeModel(kernel, invocationContext);
return MonoConverter.fromApiFuture(model.generateContentAsync(contents))
.doOnError(e -> LOGGER.error("Error generating chat completion", e))
.flatMap(result -> {
// Get ChatMessageContent from the response
GeminiChatMessageContent<?> response = getGeminiChatMessageContentFromResponse(
Expand Down Expand Up @@ -144,8 +150,8 @@ private Mono<List<ChatMessageContent<?>>> internalChatMessageContentsAsync(
invocationContext, invocationAttempts - 1);
});
});
} catch (IOException e) {
throw new RuntimeException(e);
} catch (SKCheckedException | IOException e) {
return Mono.error(new SKException("Error generating chat completion", e));
}
}

Expand Down Expand Up @@ -223,12 +229,15 @@ private GeminiChatMessageContent<?> getGeminiChatMessageContentFromResponse(
});
});

FunctionResultMetadata<GenerateContentResponse.UsageMetadata> metadata = FunctionResultMetadata
.build(UUID.randomUUID().toString(), response.getUsageMetadata(), OffsetDateTime.now());

return new GeminiChatMessageContent<>(AuthorRole.ASSISTANT,
message.toString(), null, null, null, null, functionCalls);
message.toString(), null, null, null, metadata, functionCalls);
}

private GenerativeModel getGenerativeModel(@Nullable Kernel kernel,
@Nullable InvocationContext invocationContext) {
@Nullable InvocationContext invocationContext) throws SKCheckedException {
GenerativeModel.Builder modelBuilder = new GenerativeModel.Builder()
.setModelName(getModelId())
.setVertexAi(getClient());
Expand All @@ -239,10 +248,11 @@ private GenerativeModel getGenerativeModel(@Nullable Kernel kernel,

if (settings.getResultsPerPrompt() < 1
|| settings.getResultsPerPrompt() > MAX_RESULTS_PER_PROMPT) {
throw new AIException(AIException.ErrorCodes.INVALID_REQUEST,
String.format(
"Results per prompt must be in range between 1 and %d, inclusive.",
MAX_RESULTS_PER_PROMPT));
throw SKCheckedException.build("Error building generative model.",
new AIException(AIException.ErrorCodes.INVALID_REQUEST,
String.format(
"Results per prompt must be in range between 1 and %d, inclusive.",
MAX_RESULTS_PER_PROMPT)));
}

GenerationConfig config = GenerationConfig.newBuilder()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
// Copyright (c) Microsoft. All rights reserved.
package com.microsoft.semantickernel.aiservices.google.textcompletion;

import com.google.cloud.vertexai.VertexAI;
import com.google.cloud.vertexai.api.GenerateContentResponse;
import com.google.cloud.vertexai.api.GenerationConfig;
import com.google.cloud.vertexai.generativeai.GenerativeModel;
import com.microsoft.semantickernel.Kernel;
import com.microsoft.semantickernel.aiservices.google.GeminiService;
import com.microsoft.semantickernel.aiservices.google.implementation.MonoConverter;
import com.microsoft.semantickernel.exceptions.AIException;
import com.microsoft.semantickernel.exceptions.SKCheckedException;
import com.microsoft.semantickernel.exceptions.SKException;
import com.microsoft.semantickernel.orchestration.FunctionResultMetadata;
import com.microsoft.semantickernel.orchestration.PromptExecutionSettings;
import com.microsoft.semantickernel.services.gemini.GeminiServiceBuilder;
import com.microsoft.semantickernel.services.textcompletion.StreamingTextContent;
import com.microsoft.semantickernel.services.textcompletion.TextContent;
import com.microsoft.semantickernel.services.textcompletion.TextGenerationService;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.annotation.Nullable;
import java.io.IOException;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;

public class GeminiTextGenerationService extends GeminiService implements TextGenerationService {
private static final Logger LOGGER = LoggerFactory.getLogger(GeminiTextGenerationService.class);

public GeminiTextGenerationService(VertexAI client, String modelId) {
super(client, modelId);
}

public static Builder builder() {
return new Builder();
}

@Override
public Mono<List<TextContent>> getTextContentsAsync(
String prompt,
@Nullable PromptExecutionSettings executionSettings,
@Nullable Kernel kernel) {
return this.internalGetTextAsync(prompt, executionSettings);
}

@Override
public Flux<StreamingTextContent> getStreamingTextContentsAsync(
String prompt,
@Nullable PromptExecutionSettings executionSettings,
@Nullable Kernel kernel) {
return this
.internalGetTextAsync(prompt, executionSettings)
.flatMapMany(it -> Flux.fromStream(it.stream())
.map(StreamingTextContent::new));
}

private Mono<List<TextContent>> internalGetTextAsync(String prompt,
@Nullable PromptExecutionSettings executionSettings) {

try {
GenerativeModel model = getGenerativeModel(executionSettings);
return MonoConverter.fromApiFuture(model.generateContentAsync(prompt))
.doOnError(e -> LOGGER.error("Error generating text", e))
.flatMap(result -> {
List<TextContent> textContents = new ArrayList<>();

FunctionResultMetadata<GenerateContentResponse.UsageMetadata> metadata = FunctionResultMetadata
.build(
UUID.randomUUID().toString(),
result.getUsageMetadata(),
OffsetDateTime.now());

result.getCandidatesList().forEach(
candidate -> {
candidate.getContent().getPartsList().forEach(part -> {
if (!part.getText().isEmpty()) {
textContents.add(
new TextContent(part.getText(), getModelId(), metadata));
}
});
});

return Mono.just(textContents);
});
} catch (SKCheckedException | IOException e) {
return Mono.error(new SKException("Error generating text", e));
}
}

private GenerativeModel getGenerativeModel(
@Nullable PromptExecutionSettings executionSettings) throws SKCheckedException {
GenerativeModel.Builder modelBuilder = new GenerativeModel.Builder()
.setModelName(getModelId())
.setVertexAi(getClient());

if (executionSettings != null) {
if (executionSettings.getResultsPerPrompt() < 1
|| executionSettings.getResultsPerPrompt() > MAX_RESULTS_PER_PROMPT) {
throw SKCheckedException.build("Error building generative model.",
new AIException(AIException.ErrorCodes.INVALID_REQUEST,
String.format(
"Results per prompt must be in range between 1 and %d, inclusive.",
MAX_RESULTS_PER_PROMPT)));
}

GenerationConfig config = GenerationConfig.newBuilder()
.setMaxOutputTokens(executionSettings.getMaxTokens())
.setTemperature((float) executionSettings.getTemperature())
.setTopP((float) executionSettings.getTopP())
.setCandidateCount(executionSettings.getResultsPerPrompt())
.build();

modelBuilder.setGenerationConfig(config);
}

return modelBuilder.build();
}

public static class Builder extends
GeminiServiceBuilder<GeminiTextGenerationService, GeminiTextGenerationService.Builder> {
@Override
public GeminiTextGenerationService build() {
if (this.client == null) {
throw new AIException(AIException.ErrorCodes.INVALID_REQUEST,
"VertexAI client must be provided");
}

if (this.modelId == null || modelId.isEmpty()) {
throw new AIException(AIException.ErrorCodes.INVALID_REQUEST,
"Gemini model id must be provided");
}

return new GeminiTextGenerationService(client, modelId);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.azure.ai.openai.models.ChatRequestToolMessage;
import com.azure.ai.openai.models.ChatRequestUserMessage;
import com.azure.ai.openai.models.ChatResponseMessage;
import com.azure.ai.openai.models.CompletionsUsage;
import com.azure.ai.openai.models.FunctionCall;
import com.azure.core.util.BinaryData;
import com.fasterxml.jackson.core.JsonProcessingException;
Expand Down Expand Up @@ -512,7 +513,7 @@ private OpenAIFunctionToolCall extractOpenAIFunctionToolCall(

private Mono<List<OpenAIChatMessageContent>> getChatMessageContentsAsync(
ChatCompletions completions) {
FunctionResultMetadata completionMetadata = FunctionResultMetadata.build(
FunctionResultMetadata<CompletionsUsage> completionMetadata = FunctionResultMetadata.build(
completions.getId(),
completions.getUsage(),
completions.getCreatedAt());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import com.azure.ai.openai.OpenAIAsyncClient;
import com.azure.ai.openai.models.CompletionsOptions;
import com.azure.ai.openai.models.CompletionsUsage;
import com.microsoft.semantickernel.Kernel;
import com.microsoft.semantickernel.aiservices.openai.OpenAiService;
import com.microsoft.semantickernel.aiservices.openai.implementation.OpenAIRequestSettings;
Expand Down Expand Up @@ -91,7 +92,7 @@ protected Mono<List<TextContent>> internalCompleteTextAsync(
return Mono.just(completionsResult.getValue());
})
.map(completions -> {
FunctionResultMetadata metadata = FunctionResultMetadata.build(
FunctionResultMetadata<CompletionsUsage> metadata = FunctionResultMetadata.build(
completions.getId(),
completions.getUsage(),
completions.getCreatedAt());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import com.azure.ai.openai.OpenAIAsyncClient;
import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.ai.openai.models.CompletionsUsage;
import com.azure.core.credential.AzureKeyCredential;
import com.azure.core.credential.KeyCredential;
import com.microsoft.semantickernel.Kernel;
Expand Down Expand Up @@ -68,9 +69,9 @@ public static void main(String[] args) {
// Display results
System.out.println(result.getResult());
System.out.println(
"Usage: " + result
"Usage: " + ((CompletionsUsage) result
.getMetadata()
.getUsage().getTotalTokens());
.getUsage()).getTotalTokens());
System.out.println();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import com.azure.ai.openai.models.ChatCompletionsOptions;
import com.azure.ai.openai.models.ChatRequestMessage;
import com.azure.ai.openai.models.ChatRequestSystemMessage;
import com.azure.ai.openai.models.CompletionsUsage;
import com.azure.core.credential.AzureKeyCredential;
import com.azure.core.credential.KeyCredential;
import com.microsoft.semantickernel.Kernel;
Expand Down Expand Up @@ -111,10 +112,10 @@ private static void getUsageAsync(Kernel kernel) {

FunctionInvokedHook postExecutionHandler = event -> {
System.out.println(
event.getFunction().getName() + " : Post Execution Handler - Usage: " + event
event.getFunction().getName() + " : Post Execution Handler - Usage: " + ((CompletionsUsage) event
.getResult()
.getMetadata()
.getUsage()
.getUsage())
.getTotalTokens());
return event;
};
Expand Down
Loading

0 comments on commit 9f409dc

Please sign in to comment.