Java: Fix linting issues (microsoft#4800)
Fix linting issues
johnoliver committed Jan 31, 2024
1 parent 751da5e commit 03ae300
Showing 79 changed files with 1,312 additions and 598 deletions.
OpenAIChatCompletion.java
@@ -26,6 +26,7 @@
import com.microsoft.semantickernel.chatcompletion.ChatMessageContent;
import com.microsoft.semantickernel.chatcompletion.StreamingChatMessageContent;
import com.microsoft.semantickernel.exceptions.AIException;
import com.microsoft.semantickernel.exceptions.SKException;
import com.microsoft.semantickernel.hooks.KernelHooks;
import com.microsoft.semantickernel.hooks.PreChatCompletionEvent;
import com.microsoft.semantickernel.orchestration.FunctionResult;
@@ -56,9 +57,15 @@ public class OpenAIChatCompletion implements ChatCompletionService {
private static final Logger LOGGER = LoggerFactory.getLogger(OpenAIChatCompletion.class);
private final OpenAIAsyncClient client;
private final Map<String, ContextVariable<?>> attributes;

@Nullable
private final String serviceId;

public OpenAIChatCompletion(OpenAIAsyncClient client, String modelId, String serviceId) {
public OpenAIChatCompletion(
OpenAIAsyncClient client,
String modelId,
@Nullable
String serviceId) {
this.serviceId = serviceId;
this.client = client;
this.attributes = new HashMap<>();
@@ -75,15 +82,19 @@ public Map<String, ContextVariable<?>> getAttributes() {
}

@Override
@Nullable
public String getServiceId() {
return serviceId;
}

@Override
public Mono<List<ChatMessageContent>> getChatMessageContentsAsync(
ChatHistory chatHistory,
@Nullable
PromptExecutionSettings promptExecutionSettings,
@Nullable
Kernel kernel,
@Nullable
KernelHooks kernelHooks) {

List<ChatRequestMessage> chatRequestMessages = getChatRequestMessages(chatHistory);
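A recurring pattern in this commit is threading @Nullable (javax.annotation, already imported in these files) through fields, parameters, and return types so a static null-checker can enforce guards at call sites. A minimal sketch of the idea, with hypothetical names (the checker is assumed to be in the NullAway/Error Prone family, consistent with the suppressions elsewhere in this diff):

import javax.annotation.Nullable;

class ExampleService {

    @Nullable
    private final String serviceId;

    ExampleService(@Nullable String serviceId) {
        this.serviceId = serviceId;
    }

    @Nullable
    String getServiceId() {
        return serviceId;
    }

    static void printIdLength(ExampleService service) {
        String id = service.getServiceId();
        // A NullAway-style checker flags id.length() here unless the
        // nullable value is guarded first:
        if (id != null) {
            System.out.println(id.length());
        }
    }

    public static void main(String[] args) {
        printIdLength(new ExampleService(null));           // prints nothing
        printIdLength(new ExampleService("azure-openai")); // prints 12
    }
}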
@@ -112,6 +123,7 @@ public Mono<List<ChatMessageContent>> getChatMessageContentsAsync(
@Override
public Mono<List<ChatMessageContent>> getChatMessageContentsAsync(
String prompt,
@Nullable
PromptExecutionSettings promptExecutionSettings,
Kernel kernel,
@Nullable
@@ -138,7 +150,9 @@ public Mono<List<ChatMessageContent>> getChatMessageContentsAsync(String prompt,

private Mono<List<ChatMessageContent>> internalChatMessageContentsAsync(
List<ChatRequestMessage> chatRequestMessages,
@Nullable
List<FunctionDefinition> functions,
@Nullable
PromptExecutionSettings promptExecutionSettings,
KernelHooks kernelHooks) {

@@ -242,7 +256,8 @@ private Function<ChatResponseMessage, ChatResponseCollector> accumulateResponse(
* {"type":"function", "function": {"name":"search-search", "parameters": {"query":"Banksy"}}}
* where 'name' is <plugin name '-' function name>.
*/
private Mono<StreamingChatMessageContent> invokeTool(Kernel kernel, String json) {
@SuppressWarnings("UnusedMethod")
private Mono<StreamingChatMessageContent<?>> invokeTool(Kernel kernel, String json) {
try {
ObjectMapper mapper = new ObjectMapper();
JsonNode jsonNode = mapper.readTree(json);
@@ -254,8 +269,10 @@ private Mono<StreamingChatMessageContent> invokeTool(Kernel kernel, String json)
if (result != null) {
return result.map(contextVariable -> {
String content = contextVariable.getResult();
return new StreamingChatMessageContent(AuthorRole.TOOL, content).setModelId(
id);
if (content == null) {
throw new SKException("Function result must not be null");
}
return new StreamingChatMessageContent<>(AuthorRole.TOOL, content, id);
});
}
}
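The comment above documents the tool-call payload shape. A rough sketch of how such a payload can be unpacked with Jackson (class and variable names here are illustrative, not part of the commit; jackson-databind is assumed on the classpath):

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;

public class ToolCallParseSketch {
    public static void main(String[] args) throws Exception {
        String json =
            "{\"type\":\"function\", \"function\": "
                + "{\"name\":\"search-search\", \"parameters\": {\"query\":\"Banksy\"}}}";
        ObjectMapper mapper = new ObjectMapper();
        JsonNode root = mapper.readTree(json);
        JsonNode function = root.get("function");
        // 'name' encodes <plugin name>-<function name>, per the comment above.
        String name = function.get("name").asText();
        JsonNode parameters = function.get("parameters");
        System.out.println(name + " -> " + parameters.get("query").asText());
    }
}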
@@ -268,6 +285,7 @@ private Mono<StreamingChatMessageContent> invokeTool(Kernel kernel, String json)
/*
* The jsonNode should represent: {"name":"search-search", "parameters": {"query":"Banksy"}}}
*/
@SuppressWarnings("StringSplitter")
private Mono<FunctionResult<String>> invokeFunction(Kernel kernel, JsonNode jsonNode) {
String name = jsonNode.get("name").asText();
String[] parts = name.split("-");
@@ -298,23 +316,28 @@ private Mono<FunctionResult<String>> invokeFunction(Kernel kernel, JsonNode json
private static ChatCompletionsOptions getCompletionsOptions(
ChatCompletionService chatCompletionService,
List<ChatRequestMessage> chatRequestMessages,
@Nullable
List<FunctionDefinition> functions,
@Nullable
PromptExecutionSettings promptExecutionSettings) {

ChatCompletionsOptions options = new ChatCompletionsOptions(chatRequestMessages)
.setModel(chatCompletionService.getModelId());

if (promptExecutionSettings == null) {
return options;
}

if (promptExecutionSettings.getResultsPerPrompt() < 1
|| promptExecutionSettings.getResultsPerPrompt() > MAX_RESULTS_PER_PROMPT) {
throw new AIException(AIException.ErrorCodes.INVALID_REQUEST, String.format("Results per prompt must be in range between 1 and %d, inclusive.", MAX_RESULTS_PER_PROMPT));
|| promptExecutionSettings.getResultsPerPrompt() > MAX_RESULTS_PER_PROMPT) {
throw new AIException(AIException.ErrorCodes.INVALID_REQUEST,
String.format("Results per prompt must be in range between 1 and %d, inclusive.",
MAX_RESULTS_PER_PROMPT));
}

List<ChatCompletionsToolDefinition> toolDefinitions =
chatCompletionsToolDefinitions(promptExecutionSettings.getToolCallBehavior(), functions);
List<ChatCompletionsToolDefinition> toolDefinitions =
chatCompletionsToolDefinitions(promptExecutionSettings.getToolCallBehavior(),
functions);
if (toolDefinitions != null && !toolDefinitions.isEmpty()) {
options.setTools(toolDefinitions);
// TODO: options.setToolChoices(toolChoices);
@@ -351,18 +374,22 @@ private static ChatCompletionsOptions getCompletionsOptions(
return options;
}

@SuppressWarnings("StringSplitter")
private static List<ChatCompletionsToolDefinition> chatCompletionsToolDefinitions(
@Nullable
ToolCallBehavior toolCallBehavior,
@Nullable
List<FunctionDefinition> functions) {

if (functions == null || functions.isEmpty()) {
return Collections.emptyList();
}

if (toolCallBehavior == null || !(toolCallBehavior.kernelFunctionsEnabled() || toolCallBehavior.autoInvokeEnabled())) {

if (toolCallBehavior == null || !(toolCallBehavior.kernelFunctionsEnabled()
|| toolCallBehavior.autoInvokeEnabled())) {
return Collections.emptyList();
}

return functions.stream()
.filter(function -> {
String[] parts = function.getName().split("-");
@@ -384,6 +411,9 @@ private static List<ChatRequestMessage> getChatRequestMessages(ChatHistory chatH
.map(message -> {
AuthorRole authorRole = message.getAuthorRole();
String content = message.getContent();
if (content == null) {
throw new SKException("ChatMessageContent content must not be null");
}
return getChatRequestMessage(authorRole, content);
})
.collect(Collectors.toList());
@@ -405,7 +435,7 @@ static ChatRequestMessage getChatRequestMessage(
return new ChatRequestToolMessage(content, null);
default:
LOGGER.debug("Unexpected author role: " + authorRole);
return null;
throw new SKException("Unexpected author role: " + authorRole);
}

}
@@ -459,7 +489,10 @@ public ChatMessageContent toChatMessageContent() {
private static class ToolContentBuffer implements
ContentBuffer<ChatCompletionsFunctionToolCall> {

@Nullable
private String id = null;

@Nullable
private String name = null;
private List<String> arguments = new ArrayList<>();

@@ -575,6 +608,17 @@ public static class Builder extends ChatCompletionService.Builder {

@Override
public OpenAIChatCompletion build() {

if (this.client == null) {
throw new AIException(AIException.ErrorCodes.INVALID_REQUEST,
"OpenAI client must be provided");
}

if (this.modelId == null || modelId.isEmpty()) {
throw new AIException(AIException.ErrorCodes.INVALID_REQUEST,
"OpenAI model id must be provided");
}

return new OpenAIChatCompletion(client, modelId, serviceId);
}
}
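With these checks, a misconfigured builder fails fast with an AIException instead of a NullPointerException later in the request path. A usage sketch, assuming a ChatCompletionService.builder() mirroring the withOpenAIAsyncClient/withModelId setters that the TextGenerationService test later in this diff uses (the endpoint and model id below are placeholders):

import com.azure.ai.openai.OpenAIAsyncClient;
import com.azure.ai.openai.OpenAIClientBuilder;
import com.microsoft.semantickernel.chatcompletion.ChatCompletionService;

public class BuilderUsageSketch {
    public static void main(String[] args) {
        OpenAIAsyncClient client = new OpenAIClientBuilder()
            .endpoint("http://localhost:8080") // placeholder, as in the WireMock test below
            .buildAsyncClient();

        ChatCompletionService chat = ChatCompletionService.builder()
            .withOpenAIAsyncClient(client)
            .withModelId("gpt-35-turbo") // must be non-empty after this commit
            .build();

        // Omitting the client or the model id now throws
        // AIException(ErrorCodes.INVALID_REQUEST) from build().
        System.out.println(chat.getServiceId());
    }
}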
ParsedPrompt.java
@@ -2,16 +2,22 @@

import com.azure.ai.openai.models.ChatRequestMessage;
import com.azure.ai.openai.models.FunctionDefinition;
import java.util.ArrayList;
import java.util.List;
import javax.annotation.Nullable;

class ParsedPrompt {

private final List<ChatRequestMessage> chatRequestMessages;
private final List<FunctionDefinition> functions;

protected ParsedPrompt(List<ChatRequestMessage> parsedMessages,
@Nullable
List<FunctionDefinition> parsedFunctions) {
this.chatRequestMessages = parsedMessages;
if (parsedFunctions == null) {
parsedFunctions = new ArrayList<>();
}
this.functions = parsedFunctions;
}

@@ -144,7 +144,9 @@ private static List<FunctionDefinition> getFunctionDefinitions(String prompt) {
// },
// },
//}
assert functionDefinition != null;
if (functionDefinition == null) {
throw new SKException("Failed to parse function definition");
}
if (!parameters.isEmpty()) {
StringBuilder sb = new StringBuilder(
"{\"type\": \"object\", \"properties\": {");
@@ -207,11 +209,11 @@ private static ChatRequestMessage getChatRequestMessage(
String role,
String content) {
try {
AuthorRole authorRole = AuthorRole.valueOf(role.toUpperCase());
AuthorRole authorRole = AuthorRole.valueOf(role.toUpperCase(Locale.ROOT));
return OpenAIChatCompletion.getChatRequestMessage(authorRole, content);
} catch (IllegalArgumentException e) {
LOGGER.debug("Unknown author role: " + role);
return null;
throw new SKException("Unknown author role: " + role);
}
}
}
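The switch to toUpperCase(Locale.ROOT) fixes a locale-sensitivity lint: with a Turkish default locale, 'i' uppercases to a dotted 'İ', so AuthorRole.valueOf would reject an otherwise valid role name. A small demonstration:

import java.util.Locale;

public class LocaleUpperCase {
    public static void main(String[] args) {
        Locale.setDefault(new Locale("tr", "TR"));
        // Default-locale uppercasing is locale-sensitive:
        System.out.println("assistant".toUpperCase());            // ASSİSTANT (dotted İ)
        // Locale.ROOT gives the locale-independent form the enum expects:
        System.out.println("assistant".toUpperCase(Locale.ROOT)); // ASSISTANT
    }
}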
OpenAITextGenerationService.java
@@ -25,6 +25,7 @@ public class OpenAITextGenerationService implements TextGenerationService {

private final OpenAIAsyncClient client;
private final Map<String, ContextVariable<?>> attributes;
@Nullable
private final String serviceId;

/// <summary>
@@ -38,7 +39,9 @@ public class OpenAITextGenerationService implements TextGenerationService {
/// <param name="loggerFactory">The <see cref="ILoggerFactory"/> to use for logging. If null, no logging will be performed.</param>
public OpenAITextGenerationService(
OpenAIAsyncClient client,
String modelId, String serviceId) {
String modelId,
@Nullable
String serviceId) {
this.serviceId = serviceId;
this.client = client;
attributes = new HashMap<>();
@@ -55,13 +58,16 @@ public Map<String, ContextVariable<?>> getAttributes() {
}

@Override
@Nullable
public String getServiceId() {
return serviceId;
}

@Override
public Mono<List<TextContent>> getTextContentsAsync(String prompt,
@Nullable PromptExecutionSettings executionSettings, @Nullable Kernel kernel) {
public Mono<List<TextContent>> getTextContentsAsync(
String prompt,
@Nullable PromptExecutionSettings executionSettings,
@Nullable Kernel kernel) {
return this.internalCompleteTextAsync(prompt, executionSettings);
}

@@ -78,6 +84,7 @@ public Flux<StreamingTextContent> getStreamingTextContentsAsync(

protected Mono<List<TextContent>> internalCompleteTextAsync(
String text,
@Nullable
PromptExecutionSettings requestSettings) {

CompletionsOptions completionsOptions = getCompletionsOptions(text, requestSettings);
@@ -118,7 +125,9 @@ public static Map<String, ContextVariable<?>> buildMetadata(
}

private CompletionsOptions getCompletionsOptions(
String text, PromptExecutionSettings requestSettings) {
String text,
@Nullable
PromptExecutionSettings requestSettings) {
if (requestSettings == null) {
return new CompletionsOptions(Collections.singletonList(text))
.setMaxTokens(PromptExecutionSettings.DEFAULT_MAX_TOKENS);
@@ -127,8 +136,10 @@ private CompletionsOptions getCompletionsOptions(
throw new AIException(AIException.ErrorCodes.INVALID_REQUEST, "Max tokens must be >0");
}
if (requestSettings.getResultsPerPrompt() < 1
|| requestSettings.getResultsPerPrompt() > MAX_RESULTS_PER_PROMPT) {
throw new AIException(AIException.ErrorCodes.INVALID_REQUEST, String.format("Results per prompt must be in range between 1 and %d, inclusive.", MAX_RESULTS_PER_PROMPT));
|| requestSettings.getResultsPerPrompt() > MAX_RESULTS_PER_PROMPT) {
throw new AIException(AIException.ErrorCodes.INVALID_REQUEST,
String.format("Results per prompt must be in range between 1 and %d, inclusive.",
MAX_RESULTS_PER_PROMPT));
}

CompletionsOptions options =
@@ -143,12 +154,6 @@ private CompletionsOptions getCompletionsOptions(
.setUser(requestSettings.getUser())
.setBestOf(requestSettings.getBestOf())
.setLogitBias(new HashMap<>());
/*
if (requestSettings instanceof ChatRequestSettings) {
options = options.setStop(requestSettings.getStopSequences());
}
*/
return options;
}

@@ -157,7 +162,18 @@ private CompletionsOptions getCompletionsOptions(
*/
public static class Builder extends TextGenerationService.Builder {

@Override
public TextGenerationService build() {

if (this.client == null) {
throw new AIException(AIException.ErrorCodes.INVALID_REQUEST,
"OpenAI client must be provided");
}
if (this.modelId == null) {
throw new AIException(AIException.ErrorCodes.INVALID_REQUEST,
"OpenAI model id must be provided");
}

return new OpenAITextGenerationService(
this.client,
this.modelId,
@@ -6,7 +6,7 @@
import com.microsoft.semantickernel.orchestration.contextvariables.KernelArguments;
import com.microsoft.semantickernel.plugin.KernelPlugin;
import com.microsoft.semantickernel.plugin.KernelPluginFactory;
import com.microsoft.semantickernel.samples.syntaxexamples.Example03_Arguments.StaticTextSkill;
import com.microsoft.semantickernel.samples.syntaxexamples.Example03_Arguments.StaticTextPlugin;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;

@@ -22,7 +22,7 @@ public void main() {

// Load native skill
KernelPlugin functionCollection =
KernelPluginFactory.createFromObject(new StaticTextSkill(), "text");
KernelPluginFactory.createFromObject(new StaticTextPlugin(), "text");

KernelArguments arguments = KernelArguments.builder()
.withInput("Today is: ")
@@ -29,7 +29,7 @@ public void main(WireMockRuntimeInfo wmRuntimeInfo) {
.endpoint("http://localhost:" + wmRuntimeInfo.getHttpPort())
.buildAsyncClient();

TextGenerationService textGenerationService = OpenAITextGenerationService.builder()
TextGenerationService textGenerationService = TextGenerationService.builder()
.withOpenAIAsyncClient(client)
.withModelId("text-davinci-003")
.build();
@@ -55,6 +55,7 @@ public void main(WireMockRuntimeInfo wmRuntimeInfo) {
""".stripIndent();

var excuseFunction = new KernelFunctionFromPrompt.Builder()
.withName("Excuse")
.withTemplate(promptTemplate)
.withDefaultExecutionSettings(
new PromptExecutionSettings.Builder()
4 changes: 4 additions & 0 deletions java/samples/pom.xml
@@ -126,6 +126,10 @@
<target>${maven.compiler.release}</target>
<encoding>${project.build.sourceEncoding}</encoding>
<showWarnings>true</showWarnings>
<compilerArgs>
<arg></arg>
<compilerArg></compilerArg>
</compilerArgs>
</configuration>
</plugin>
