Skip to content

Commit

Permalink
Add suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
milderhc committed Jun 7, 2024
1 parent 9651244 commit 9a1c6e3
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import com.microsoft.semantickernel.aiservices.google.implementation.MonoConverter;
import com.microsoft.semantickernel.contextvariables.ContextVariableTypes;
import com.microsoft.semantickernel.exceptions.AIException;
import com.microsoft.semantickernel.exceptions.SKCheckedException;
import com.microsoft.semantickernel.exceptions.SKException;
import com.microsoft.semantickernel.orchestration.FunctionResultMetadata;
import com.microsoft.semantickernel.orchestration.InvocationContext;
import com.microsoft.semantickernel.orchestration.InvocationReturnMode;
Expand All @@ -40,8 +42,10 @@

import javax.annotation.Nullable;
import java.io.IOException;
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
import java.util.stream.Collectors;

public class GeminiChatCompletion extends GeminiService implements ChatCompletionService {
Expand Down Expand Up @@ -87,11 +91,13 @@ public Mono<List<ChatMessageContent<?>>> getChatMessageContentsAsync(ChatHistory
private Mono<List<ChatMessageContent<?>>> internalChatMessageContentsAsync(
ChatHistory fullHistory, ChatHistory newHistory, @Nullable Kernel kernel,
@Nullable InvocationContext invocationContext, int invocationAttempts) {

List<Content> contents = getContents(fullHistory);
GenerativeModel model = getGenerativeModel(kernel, invocationContext);

try {
GenerativeModel model = getGenerativeModel(kernel, invocationContext);
return MonoConverter.fromApiFuture(model.generateContentAsync(contents))
.doOnError(e -> LOGGER.error("Error generating chat completion", e))
.flatMap(result -> {
// Get ChatMessageContent from the response
GeminiChatMessageContent<?> response = getGeminiChatMessageContentFromResponse(
Expand Down Expand Up @@ -144,8 +150,8 @@ private Mono<List<ChatMessageContent<?>>> internalChatMessageContentsAsync(
invocationContext, invocationAttempts - 1);
});
});
} catch (IOException e) {
throw new RuntimeException(e);
} catch (SKCheckedException | IOException e) {
return Mono.error(new SKException("Error generating chat completion", e));
}
}

Expand Down Expand Up @@ -224,14 +230,14 @@ private GeminiChatMessageContent<?> getGeminiChatMessageContentFromResponse(
});

FunctionResultMetadata<GenerateContentResponse.UsageMetadata> metadata = FunctionResultMetadata
.build(null, response.getUsageMetadata(), null);
.build(UUID.randomUUID().toString(), response.getUsageMetadata(), OffsetDateTime.now());

return new GeminiChatMessageContent<>(AuthorRole.ASSISTANT,
message.toString(), null, null, null, metadata, functionCalls);
}

private GenerativeModel getGenerativeModel(@Nullable Kernel kernel,
@Nullable InvocationContext invocationContext) {
@Nullable InvocationContext invocationContext) throws SKCheckedException {
GenerativeModel.Builder modelBuilder = new GenerativeModel.Builder()
.setModelName(getModelId())
.setVertexAi(getClient());
Expand All @@ -242,10 +248,11 @@ private GenerativeModel getGenerativeModel(@Nullable Kernel kernel,

if (settings.getResultsPerPrompt() < 1
|| settings.getResultsPerPrompt() > MAX_RESULTS_PER_PROMPT) {
throw new AIException(AIException.ErrorCodes.INVALID_REQUEST,
String.format(
"Results per prompt must be in range between 1 and %d, inclusive.",
MAX_RESULTS_PER_PROMPT));
throw SKCheckedException.build("Error building generative model.",
new AIException(AIException.ErrorCodes.INVALID_REQUEST,
String.format(
"Results per prompt must be in range between 1 and %d, inclusive.",
MAX_RESULTS_PER_PROMPT)));
}

GenerationConfig config = GenerationConfig.newBuilder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
import com.microsoft.semantickernel.aiservices.google.GeminiService;
import com.microsoft.semantickernel.aiservices.google.implementation.MonoConverter;
import com.microsoft.semantickernel.exceptions.AIException;
import com.microsoft.semantickernel.exceptions.SKCheckedException;
import com.microsoft.semantickernel.exceptions.SKException;
import com.microsoft.semantickernel.orchestration.FunctionResultMetadata;
import com.microsoft.semantickernel.orchestration.PromptExecutionSettings;
import com.microsoft.semantickernel.services.gemini.GeminiServiceBuilder;
Expand All @@ -26,6 +28,7 @@
import java.time.OffsetDateTime;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;

public class GeminiTextGenerationService extends GeminiService implements TextGenerationService {
private static final Logger LOGGER = LoggerFactory.getLogger(GeminiTextGenerationService.class);
Expand Down Expand Up @@ -59,19 +62,19 @@ public Flux<StreamingTextContent> getStreamingTextContentsAsync(

private Mono<List<TextContent>> internalGetTextAsync(String prompt,
@Nullable PromptExecutionSettings executionSettings) {
GenerativeModel model = getGenerativeModel(executionSettings);

try {
GenerativeModel model = getGenerativeModel(executionSettings);
return MonoConverter.fromApiFuture(model.generateContentAsync(prompt))
.doOnError(e -> LOGGER.error("Error generating text", e))
.flatMap(result -> {
List<TextContent> textContents = new ArrayList<>();

FunctionResultMetadata<GenerateContentResponse.UsageMetadata> metadata = FunctionResultMetadata
.build(
null,
UUID.randomUUID().toString(),
result.getUsageMetadata(),
null);
OffsetDateTime.now());

result.getCandidatesList().forEach(
candidate -> {
Expand All @@ -85,24 +88,25 @@ private Mono<List<TextContent>> internalGetTextAsync(String prompt,

return Mono.just(textContents);
});
} catch (IOException e) {
throw new RuntimeException(e);
} catch (SKCheckedException | IOException e) {
return Mono.error(new SKException("Error generating text", e));
}
}

private GenerativeModel getGenerativeModel(
@Nullable PromptExecutionSettings executionSettings) {
@Nullable PromptExecutionSettings executionSettings) throws SKCheckedException {
GenerativeModel.Builder modelBuilder = new GenerativeModel.Builder()
.setModelName(getModelId())
.setVertexAi(getClient());

if (executionSettings != null) {
if (executionSettings.getResultsPerPrompt() < 1
|| executionSettings.getResultsPerPrompt() > MAX_RESULTS_PER_PROMPT) {
throw new AIException(AIException.ErrorCodes.INVALID_REQUEST,
String.format(
"Results per prompt must be in range between 1 and %d, inclusive.",
MAX_RESULTS_PER_PROMPT));
throw SKCheckedException.build("Error building generative model.",
new AIException(AIException.ErrorCodes.INVALID_REQUEST,
String.format(
"Results per prompt must be in range between 1 and %d, inclusive.",
MAX_RESULTS_PER_PROMPT)));
}

GenerationConfig config = GenerationConfig.newBuilder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ public FunctionResultMetadata(CaseInsensitiveMap<ContextVariable<?>> metadata) {
/**
* Create a new instance of FunctionResultMetadata.
*/
public static FunctionResultMetadata<?> build(@Nullable String id) {
public static FunctionResultMetadata<?> build(String id) {
return build(id, null, null);
}

Expand All @@ -64,15 +64,14 @@ public static FunctionResultMetadata<?> build(@Nullable String id) {
* @return A new instance of FunctionResultMetadata.
*/
public static <UsageType> FunctionResultMetadata<UsageType> build(
@Nullable String id,
String id,
@Nullable UsageType usage,
@Nullable OffsetDateTime createdAt) {

CaseInsensitiveMap<ContextVariable<?>> metadata = new CaseInsensitiveMap<>();

if (id != null) {
metadata.put(ID, ContextVariable.of(id));
}
metadata.put(ID, ContextVariable.of(id));

if (usage != null) {
metadata.put(USAGE, ContextVariable.of(usage,
new ContextVariableTypeConverter.NoopConverter<>(Object.class)));
Expand Down Expand Up @@ -107,7 +106,6 @@ public CaseInsensitiveMap<ContextVariable<?>> getMetadata() {
*
* @return The id of the result of the function invocation.
*/
@Nullable
public String getId() {
ContextVariable<?> id = metadata.get(ID);
if (id == null) {
Expand Down

0 comments on commit 9a1c6e3

Please sign in to comment.