Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Copilot Chat: Java: add Example15_MemorySkill #1999

Merged
merged 4 commits into from
Aug 1, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
// Copyright (c) Microsoft. All rights reserved.
package com.microsoft.semantickernel.samples.syntaxexamples;

import com.azure.ai.openai.OpenAIAsyncClient;
import com.microsoft.semantickernel.Kernel;
import com.microsoft.semantickernel.KernelConfig;
import com.microsoft.semantickernel.ai.embeddings.EmbeddingGeneration;
import com.microsoft.semantickernel.builders.SKBuilders;
import com.microsoft.semantickernel.coreskills.TextMemorySkill;
import com.microsoft.semantickernel.memory.MemoryStore;
import com.microsoft.semantickernel.orchestration.SKContext;
import com.microsoft.semantickernel.orchestration.SKFunction;
import com.microsoft.semantickernel.samples.Config;
import com.microsoft.semantickernel.semanticfunctions.PromptTemplateConfig;
import com.microsoft.semantickernel.textcompletion.CompletionSKFunction;
import com.microsoft.semantickernel.textcompletion.TextCompletion;
import reactor.core.publisher.Mono;

import java.util.List;
import java.util.stream.Collectors;

public class Example15_MemorySkill
{
private static final String MEMORY_COLLECTION_NAME = "aboutMe";

public static void main(String[] args) { runAsync().block(); }
public static Mono<Void> runAsync() {
// ========= Create a kernel =========
OpenAIAsyncClient client = null;
try {
client = Config.ClientType.AZURE_OPEN_AI.getClient();
} catch (Exception e) {
return Mono.error(e);
}

TextCompletion textCompletionService =
SKBuilders.textCompletionService()
.build(client, "text-davinci-003");

EmbeddingGeneration textEmbeddingGenerationService =
SKBuilders.textEmbeddingGenerationService()
.build(client, "text-embedding-ada-002");

MemoryStore memoryStore = SKBuilders.memoryStore().build();

KernelConfig kernelConfig = SKBuilders.kernelConfig()
.addTextCompletionService("text-davinci-003", kernel -> textCompletionService)
.addTextEmbeddingsGenerationService("text-embedding-ada-002", kernel -> textEmbeddingGenerationService)
.build();

Kernel kernel = SKBuilders.kernel()
.withKernelConfig(kernelConfig)
.withMemoryStore(memoryStore)
.build();

// ========= Store memories using the kernel =========
kernel.getMemory().saveInformationAsync(MEMORY_COLLECTION_NAME, "My name is Andrea", "info1", null, null).block();
kernel.getMemory().saveInformationAsync(MEMORY_COLLECTION_NAME, "I work as a tourist operator", "info2", null, null).block();
kernel.getMemory().saveInformationAsync(MEMORY_COLLECTION_NAME, "I've been living in Seattle since 2005", "info3", null, null).block();
kernel.getMemory().saveInformationAsync(MEMORY_COLLECTION_NAME, "I visited France and Italy five times since 2015", "info4", null, null).block();

// ========= Store memories using semantic function =========

// Add Memory as a skill for other functions
TextMemorySkill memorySkill = new TextMemorySkill();
kernel.importSkill(memorySkill, "memory");

// Build a semantic function that saves info to memory
PromptTemplateConfig.CompletionConfig completionConfig = SKBuilders.completionConfig()
.temperature(0.2)
.topP(0.5)
.presencePenalty(0)
.frequencyPenalty(0)
.maxTokens(2000)
.build();

CompletionSKFunction saveFunctionDefinition = SKBuilders.completionFunctions(kernel)
.createFunction(
"{{memory.save $info}}",
"save",
"",
"save information to memory",
completionConfig
);

CompletionSKFunction memorySaver =
kernel.registerSemanticFunction(saveFunctionDefinition);

SKContext context = SKBuilders.context()
.with(kernel.getMemory())
.with(kernel.getSkills())
.build();

context.setVariable(TextMemorySkill.COLLECTION_PARAM, MEMORY_COLLECTION_NAME)
.setVariable(TextMemorySkill.KEY_PARAM, "info5")
.setVariable(TextMemorySkill.INFO_PARAM, "My family is from New York");

memorySaver.invokeAsync(context).block();

// ========= Test memory remember =========
System.out.println("========= Example: Recalling a Memory =========");

// create a new context to avoid using the variables from the previous example
context = SKBuilders.context()
.with(kernel.getMemory())
.with(kernel.getSkills())
.build();

String answer = memorySkill.retrieveAsync(MEMORY_COLLECTION_NAME, "info1", context).block();
System.out.printf("Memory associated with 'info1': %s%n", answer);
/*
Output:
"Memory associated with 'info1': My name is Andrea
*/

// ========= Test memory recall =========
System.out.println("========= Example: Recalling an Idea =========");

List<String> answers = memorySkill.recallAsync(
"Where did I grow up?",
MEMORY_COLLECTION_NAME,
0,
2,
context).block();

System.out.println("Ask: Where did I grow up?");
System.out.printf("Answer:%n\t%s%n", answers.stream().collect(Collectors.joining("\",\"", "[\"", "\"]")));

answers = memorySkill.recallAsync(
"Where do I live?",
MEMORY_COLLECTION_NAME,
0,
2,
context).block();

System.out.println("Ask: Where do I live?");
System.out.printf("Answer:%n\t%s%n", answers.stream().collect(Collectors.joining("\",\"", "[\"", "\"]")));

/*
Output:

Ask: where did I grow up?
Answer:
["My family is from New York","I've been living in Seattle since 2005"]

Ask: where do I live?
Answer:
["I've been living in Seattle since 2005","My family is from New York"]
*/

// ========= Use memory in a semantic function =========
System.out.println("========= Example: Using Recall in a Semantic Function =========");

// Build a semantic function that uses memory to find facts
String prompt =
"Consider only the facts below when answering questions.\n" +
"About me: {{memory.recall 'where did I grow up?'}}\n" +
"About me: {{memory.recall 'where do I live?'}}\n" +
"Question: {{$input}}\n" +
"Answer: ";

CompletionSKFunction recallFunctionDefinition =
kernel.getSemanticFunctionBuilder()
.createFunction(
prompt,
completionConfig
);

SKFunction<?> aboutMeOracle = kernel.registerSemanticFunction(recallFunctionDefinition);

context.setVariable(TextMemorySkill.COLLECTION_PARAM, MEMORY_COLLECTION_NAME)
.setVariable(TextMemorySkill.RELEVANCE_PARAM, "0.75");

SKContext result = aboutMeOracle.invokeAsync("Do I live in the same town where I grew up?", context, null).block();

System.out.println("Do I live in the same town where I grew up?\n");
System.out.println(result.getResult());

/*
Output:

Do I live in the same town where I grew up?

No, I do not live in the same town where I grew up since my family is from New York and I have been living in Seattle since 2005.
*/

// ========= Remove a memory =========
System.out.println("========= Example: Forgetting a Memory =========");

context.setVariable("fact1", "What is my name?")
.setVariable("fact2", "What do I do for a living?")
.setVariable(TextMemorySkill.RELEVANCE_PARAM, "0.75");

result = aboutMeOracle.invokeAsync("Tell me a bit about myself", context, null).block();

System.out.println("Tell me a bit about myself\n");
System.out.println(result.getResult());

/*
Approximate Output:
Tell me a bit about myself

My name is Andrea and my family is from New York. I work as a tourist operator.
*/

memorySkill.removeAsync(MEMORY_COLLECTION_NAME, "info1", context).block();

result = aboutMeOracle.invokeAsync("Tell me a bit about myself", context, null).block();

System.out.println("Tell me a bit about myself\n");
System.out.println(result.getResult());

/*
Approximate Output:
Tell me a bit about myself

I'm from a family originally from New York and I work as a tourist operator. I've been living in Seattle since 2005.
*/

return Mono.empty();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ SKContext build(
ContextVariables getVariables();

/**
* Provides access to the contexts semantic memory
* Provides access to the context's semantic memory
*
* @return the semantic memory
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import com.microsoft.semantickernel.ai.embeddings.Embedding;
import com.microsoft.semantickernel.ai.embeddings.EmbeddingGeneration;
import com.microsoft.semantickernel.exceptions.NotSupportedException;

import reactor.core.publisher.Mono;
import reactor.util.function.Tuple2;
Expand Down Expand Up @@ -60,7 +59,15 @@ public Mono<String> saveInformationAsync(
MemoryRecord memoryRecord =
new MemoryRecord(
data, embeddings.iterator().next(), collection, null);
return _storage.upsertAsync(collection, memoryRecord);

return _storage.upsertAsync(collection, memoryRecord)
.onErrorResume(
e -> {
return _storage.createCollectionAsync(collection)
.then(
_storage.upsertAsync(
collection, memoryRecord));
});
});
}

Expand All @@ -72,8 +79,7 @@ public Mono<MemoryQueryResult> getAsync(String collection, String key, boolean w

@Override
public Mono<Void> removeAsync(@Nonnull String collection, @Nonnull String key) {
return Mono.error(new NotSupportedException("Pending implementation"));
// await this._storage.RemoveAsync(collection, key, cancel);
return _storage.removeAsync(collection, key);
}

private static final Function<
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,13 @@ public Mono<Void> deleteCollectionAsync(@Nonnull String collectionName) {
public Mono<String> upsertAsync(@Nonnull String collectionName, @Nonnull MemoryRecord record) {
// Contract:
// Does not guarantee that the collection exists.
Map<String, MemoryRecord> collection = getCollection(collectionName);
Map<String, MemoryRecord> collection = null;
try {
// getCollection throws MemoryException if the collection does not exist.
collection = getCollection(collectionName);
} catch (MemoryException e) {
return Mono.error(e);
}

String key = record.getMetadata().getId();
// Assumption is that MemoryRecord will always have a non-null id.
Expand Down Expand Up @@ -212,11 +218,14 @@ record -> {
}
}
});
return Mono.just(

List<Tuple2<MemoryRecord, Number>> result =
nearestMatches.stream()
.sorted(Comparator.comparingDouble(extractSimilarity).reversed())
.limit(limit)
.collect(Collectors.toList()));
.limit(Math.max(1, limit))
.collect(Collectors.toList());

return Mono.just(result);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
Expand Down Expand Up @@ -229,6 +235,11 @@ private static SKNativeTask<SKContext> getFunction(Method method, Object instanc

return mono.map(
it -> {
if (it instanceof Iterable) {
// Handle return from things like Mono<List<?>>
// from {{function 'input'}} as part of the prompt.
it = ((Iterable<?>) it).iterator().next();
dsgrieve marked this conversation as resolved.
Show resolved Hide resolved
}
if (it instanceof SKContext) {
return it;
} else {
Expand Down Expand Up @@ -291,7 +302,7 @@ private static String formErrorMessage(Method method, Parameter parameter) {
+ " was invoked with a required context variable missing and no default value.";
}

private static String getArgumentValue(
private static Object getArgumentValue(
Method method, SKContext context, Parameter parameter, Set<Parameter> inputArgs) {
String variableName = getGetVariableName(parameter);

Expand Down Expand Up @@ -346,7 +357,54 @@ private static String getArgumentValue(
"Unknown arg " + parameter.getName());
}
}
return arg;

SKFunctionParameters annotation = parameter.getAnnotation(SKFunctionParameters.class);
if (annotation == null || annotation.type() == null) {
return arg;
}
Class<?> type = annotation.type();
if (Number.class.isAssignableFrom(type)) {
arg = arg.replace(",", ".");
}

Object value = arg;
// Well-known types only
Function converter = converters.get(type);
if (converter != null) {
try {
value = converter.apply(arg);
} catch (NumberFormatException nfe) {
throw new AIException(
AIException.ErrorCodes.InvalidConfiguration,
"Invalid value for "
+ parameter.getName()
+ " expected "
+ type.getSimpleName()
+ " but got "
+ arg);
}
}
return value;
}

private static final Map<Class<?>, Function<String, ?>> converters = new HashMap<>();
johnoliver marked this conversation as resolved.
Show resolved Hide resolved

static {
converters.put(Boolean.class, Boolean::valueOf);
converters.put(boolean.class, Boolean::valueOf);
converters.put(Byte.class, Byte::parseByte);
converters.put(byte.class, Byte::parseByte);
converters.put(Integer.class, Integer::parseInt);
converters.put(int.class, Integer::parseInt);
converters.put(Long.class, Long::parseLong);
converters.put(long.class, Long::parseLong);
converters.put(Double.class, Double::parseDouble);
converters.put(double.class, Double::parseDouble);
converters.put(Float.class, Float::parseFloat);
converters.put(float.class, Float::parseFloat);
converters.put(Short.class, Short::parseShort);
converters.put(short.class, Short::parseShort);
converters.put(String.class, it -> it);
}

private static String getGetVariableName(Parameter parameter) {
Expand Down
Loading
Loading