Skip to content

Commit

Permalink
Java: Fix chat message parsing (#5822)
Browse files Browse the repository at this point in the history
### Motivation and Context

<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users, providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->

### Description

<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [ ] The code builds clean without any errors or warnings
- [ ] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [ ] All unit tests pass, and I have added new tests where possible
- [ ] I didn't break anyone 😄
  • Loading branch information
johnoliver committed Apr 9, 2024
1 parent 32c7e45 commit 0afe9c5
Show file tree
Hide file tree
Showing 10 changed files with 54 additions and 12 deletions.
8 changes: 7 additions & 1 deletion java/semantickernel-api/pom.xml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>

<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>

<parent>
Expand Down Expand Up @@ -50,6 +51,10 @@
<groupId>com.github.spotbugs</groupId>
<artifactId>spotbugs-annotations</artifactId>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-text</artifactId>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>
Expand All @@ -72,6 +77,7 @@
<artifactId>mockito-junit-jupiter</artifactId>
<version>5.11.0</version>
</dependency>

</dependencies>

<build>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import java.util.Optional;
import java.util.function.Function;
import javax.annotation.Nullable;
import org.apache.commons.text.StringEscapeUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -24,6 +25,7 @@ public class ContextVariableTypeConverter<T> {
ContextVariableTypeConverter.class);

public interface ToPromptStringFunction<T> {

String toPromptString(ContextVariableTypes types, T t);
}

Expand Down Expand Up @@ -389,4 +391,18 @@ public ContextVariableTypeConverter<T> build() {
fromPromptString);
}
}

/**
* To be used when toPromptString is called
*
* @param value the value to escape
* @return the escaped value
*/
@Nullable
public static String escapeXmlString(@Nullable String value) {
if (value == null) {
return null;
}
return StringEscapeUtils.escapeXml11(value);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ private static String toXmlString(ChatHistory chatHistory) {
.stream()
.map(message -> String.format("<message role=\"%s\">%s</message>",
message.getAuthorRole(),
message.getContent()))
escapeXmlString(message.getContent())))
.collect(Collectors.joining("\n"));

return String.format("<messages>%n%s%n</messages>", messages);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,15 @@ public CollectionVariableContextVariableTypeConverter() {
this(",");
}

@SuppressWarnings("NullAway")
public static ToPromptStringFunction<Collection> getString(String delimiter) {
return (types, collection) -> {
return (String) collection
String formatted = (String) collection
.stream()
.map(t -> getString(types, t))
.collect(Collectors.joining(delimiter));

return escapeXmlString(formatted);
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public PrimitiveVariableContextVariableTypeConverter(
super(
clazz,
s -> convert(s, clazz),
toPromptString,
s -> escapeXmlString(toPromptString.apply(s)),
fromPromptString);
this.fromObject = fromObject;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public StringVariableContextVariableTypeConverter() {
super(
String.class,
s -> convert(s, String.class),
Object::toString,
ContextVariableTypeConverter::escapeXmlString,
s -> s);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import com.microsoft.semantickernel.contextvariables.ContextVariableTypeConverter;
import com.microsoft.semantickernel.services.textcompletion.TextContent;
import javax.annotation.Nullable;

/**
* A converter for a context variable type. This class is used to convert objects to and from a
Expand All @@ -21,10 +22,18 @@ public TextContentVariableContextVariableTypeConverter() {
super(
TextContent.class,
s -> convert(s, TextContent.class),
TextContent::getContent,
TextContentVariableContextVariableTypeConverter::escapeXmlStringValue,
x -> {
throw new UnsupportedOperationException(
"TextContentVariableContextVariableTypeConverter does not support fromPromptString");
});
}

@Nullable
public static String escapeXmlStringValue(@Nullable TextContent value) {
if (value == null) {
return null;
}
return escapeXmlString(value.getContent());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public class PromptExecutionSettings {

/**
* The default for {@link #getFrequencyPenalty()} if
* {@link Builder#withMaxTokens(doulbe) frequency_penalty} is not provided. Defaults to
* {@link Builder#withFrequencyPenalty(double)} frequency_penalty} is not provided. Defaults to
* {@code 0.0}
*/
public static final double DEFAULT_FREQUENCY_PENALTY = 0.0;
Expand Down Expand Up @@ -97,6 +97,7 @@ public class PromptExecutionSettings {
private final String user;
private final List<String> stopSequences;
private final Map<Integer, Integer> tokenSelectionBiases;
@Nullable
private final ResponseFormat responseFormat;

/**
Expand All @@ -114,8 +115,7 @@ public class PromptExecutionSettings {
* @param user The user to associate with the prompt execution.
* @param stopSequences The stop sequences to use for prompt execution.
* @param tokenSelectionBiases The token selection biases to use for prompt execution.
* @param responseFormat The response format to use for prompt execution @{link
* ResponseFormat}.
* @param responseFormat The response format to use for prompt execution {@link ResponseFormat}.
*/
@JsonCreator
public PromptExecutionSettings(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import com.microsoft.semantickernel.Kernel;
import com.microsoft.semantickernel.contextvariables.ContextVariable;
import com.microsoft.semantickernel.contextvariables.ContextVariableType;
import com.microsoft.semantickernel.contextvariables.ContextVariableTypeConverter;
import com.microsoft.semantickernel.exceptions.SKException;
import com.microsoft.semantickernel.orchestration.InvocationContext;
import com.microsoft.semantickernel.orchestration.ToolCallBehavior;
Expand Down Expand Up @@ -85,7 +86,8 @@ public Object resolve(Object context, String name) {
if ("role".equalsIgnoreCase(name)) {
return ((ChatMessageContent) context).getAuthorRole().name();
} else if ("content".equalsIgnoreCase(name)) {
return ((ChatMessageContent) context).getContent();
return ContextVariableTypeConverter
.escapeXmlString(((ChatMessageContent) context).getContent());
}
}
return UNRESOLVED;
Expand All @@ -104,7 +106,7 @@ public Object resolve(Object context) {
"<message role=\"%s\">%s</message>",
((ChatMessageContent) context).getAuthorRole().toString()
.toLowerCase(Locale.ROOT),
content);
ContextVariableTypeConverter.escapeXmlString(content));
}
return UNRESOLVED;
}
Expand Down
8 changes: 7 additions & 1 deletion java/semantickernel-bom/pom.xml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>

<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>

<groupId>com.microsoft.semantic-kernel</groupId>
Expand Down Expand Up @@ -152,6 +153,11 @@
<artifactId>spotbugs-annotations</artifactId>
<version>4.8.3</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-text</artifactId>
<version>1.11.0</version>
</dependency>
</dependencies>
</dependencyManagement>

Expand Down

0 comments on commit 0afe9c5

Please sign in to comment.