Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions dotnet/src/Microsoft.Agents.AI.Mem0/Mem0Provider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ public sealed class Mem0Provider : AIContextProvider
private const string DefaultContextPrompt = "## Memories\nConsider the following memories when answering user questions:";

private readonly string _contextPrompt;
private readonly bool _enableSensitiveTelemetryData;

private readonly Mem0Client _client;
private readonly ILogger<Mem0Provider>? _logger;
Expand Down Expand Up @@ -64,6 +65,7 @@ public Mem0Provider(HttpClient httpClient, Mem0ProviderScope storageScope, Mem0P
this._client = new Mem0Client(httpClient);

this._contextPrompt = options?.ContextPrompt ?? DefaultContextPrompt;
this._enableSensitiveTelemetryData = options?.EnableSensitiveTelemetryData ?? false;
this._storageScope = new Mem0ProviderScope(Throw.IfNull(storageScope));
this._searchScope = searchScope ?? storageScope;

Expand Down Expand Up @@ -114,6 +116,7 @@ public Mem0Provider(HttpClient httpClient, JsonElement serializedState, JsonSeri
this._client = new Mem0Client(httpClient);

this._contextPrompt = options?.ContextPrompt ?? DefaultContextPrompt;
this._enableSensitiveTelemetryData = options?.EnableSensitiveTelemetryData ?? false;

var jso = jsonSerializerOptions ?? Mem0JsonUtilities.DefaultOptions;
var state = serializedState.Deserialize(jso.GetTypeInfo(typeof(Mem0State))) as Mem0State;
Expand Down Expand Up @@ -158,17 +161,17 @@ public override async ValueTask<AIContext> InvokingAsync(InvokingContext context
this._searchScope.ApplicationId,
this._searchScope.AgentId,
this._searchScope.ThreadId,
this._searchScope.UserId);
this.SanitizeLogData(this._searchScope.UserId));
if (outputMessageText is not null)
{
this._logger.LogTrace(
"Mem0AIContextProvider: Search Results\nInput:{Input}\nOutput:{MessageText}\nApplicationId: '{ApplicationId}', AgentId: '{AgentId}', ThreadId: '{ThreadId}', UserId: '{UserId}'.",
queryText,
outputMessageText,
this.SanitizeLogData(queryText),
this.SanitizeLogData(outputMessageText),
this._searchScope.ApplicationId,
this._searchScope.AgentId,
this._searchScope.ThreadId,
this._searchScope.UserId);
this.SanitizeLogData(this._searchScope.UserId));
}
}

Expand All @@ -189,7 +192,7 @@ public override async ValueTask<AIContext> InvokingAsync(InvokingContext context
this._searchScope.ApplicationId,
this._searchScope.AgentId,
this._searchScope.ThreadId,
this._searchScope.UserId);
this.SanitizeLogData(this._searchScope.UserId));
return new AIContext();
}
}
Expand All @@ -215,7 +218,7 @@ public override async ValueTask InvokedAsync(InvokedContext context, Cancellatio
this._storageScope.ApplicationId,
this._storageScope.AgentId,
this._storageScope.ThreadId,
this._storageScope.UserId);
this.SanitizeLogData(this._storageScope.UserId));
}
}

Expand Down Expand Up @@ -282,4 +285,6 @@ public Mem0State(Mem0ProviderScope storageScope, Mem0ProviderScope searchScope)
public Mem0ProviderScope StorageScope { get; set; }
public Mem0ProviderScope SearchScope { get; set; }
}

private string? SanitizeLogData(string? data) => this._enableSensitiveTelemetryData ? data : "<redacted>";
}
6 changes: 6 additions & 0 deletions dotnet/src/Microsoft.Agents.AI.Mem0/Mem0ProviderOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,10 @@ public sealed class Mem0ProviderOptions
/// </summary>
/// <value>Defaults to "## Memories\nConsider the following memories when answering user questions:".</value>
public string? ContextPrompt { get; set; }

/// <summary>
/// Gets or sets a value indicating whether sensitive data such as user ids and user messages may appear in logs.
/// </summary>
/// <value>Defaults to <see langword="false"/>.</value>
public bool EnableSensitiveTelemetryData { get; set; }
}
16 changes: 10 additions & 6 deletions dotnet/src/Microsoft.Agents.AI/Memory/ChatHistoryMemoryProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ public sealed class ChatHistoryMemoryProvider : AIContextProvider, IDisposable
private readonly VectorStoreCollection<object, Dictionary<string, object?>> _collection;
private readonly int _maxResults;
private readonly string _contextPrompt;
private readonly bool _enableSensitiveTelemetryData;
private readonly ChatHistoryMemoryProviderOptions.SearchBehavior _searchTime;
private readonly AITool[] _tools;
private readonly ILogger<ChatHistoryMemoryProvider>? _logger;
Expand Down Expand Up @@ -130,6 +131,7 @@ private ChatHistoryMemoryProvider(
options ??= new ChatHistoryMemoryProviderOptions();
this._maxResults = options.MaxResults.HasValue ? Throw.IfLessThanOrEqual(options.MaxResults.Value, 0) : DefaultMaxResults;
this._contextPrompt = options.ContextPrompt ?? DefaultContextPrompt;
this._enableSensitiveTelemetryData = options.EnableSensitiveTelemetryData;
this._searchTime = options.SearchTime;
this._logger = loggerFactory?.CreateLogger<ChatHistoryMemoryProvider>();

Expand Down Expand Up @@ -216,7 +218,7 @@ public override async ValueTask<AIContext> InvokingAsync(InvokingContext context
this._searchScope.ApplicationId,
this._searchScope.AgentId,
this._searchScope.ThreadId,
this._searchScope.UserId);
this.SanitizeLogData(this._searchScope.UserId));
return new AIContext();
}
}
Expand Down Expand Up @@ -268,7 +270,7 @@ public override async ValueTask InvokedAsync(InvokedContext context, Cancellatio
this._searchScope.ApplicationId,
this._searchScope.AgentId,
this._searchScope.ThreadId,
this._searchScope.UserId);
this.SanitizeLogData(this._searchScope.UserId));
}
}

Expand Down Expand Up @@ -302,12 +304,12 @@ internal async Task<string> SearchTextAsync(string userQuestion, CancellationTok

this._logger?.LogTrace(
"ChatHistoryMemoryProvider: Search Results\nInput:{Input}\nOutput:{MessageText}\n ApplicationId: '{ApplicationId}', AgentId: '{AgentId}', ThreadId: '{ThreadId}', UserId: '{UserId}'.",
userQuestion,
formatted,
this.SanitizeLogData(userQuestion),
this.SanitizeLogData(formatted),
this._searchScope.ApplicationId,
this._searchScope.AgentId,
this._searchScope.ThreadId,
this._searchScope.UserId);
this.SanitizeLogData(this._searchScope.UserId));
return formatted;
}

Expand Down Expand Up @@ -387,7 +389,7 @@ internal async Task<string> SearchTextAsync(string userQuestion, CancellationTok
this._searchScope.ApplicationId,
this._searchScope.AgentId,
this._searchScope.ThreadId,
this._searchScope.UserId);
this.SanitizeLogData(this._searchScope.UserId));

return results;
}
Expand Down Expand Up @@ -475,6 +477,8 @@ public override JsonElement Serialize(JsonSerializerOptions? jsonSerializerOptio
return serializedState.Deserialize(jso.GetTypeInfo(typeof(ChatHistoryMemoryProviderState))) as ChatHistoryMemoryProviderState;
}

private string? SanitizeLogData(string? data) => this._enableSensitiveTelemetryData ? data : "<redacted>";

internal sealed class ChatHistoryMemoryProviderState
{
public ChatHistoryMemoryProviderScope? StorageScope { get; set; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ public sealed class ChatHistoryMemoryProviderOptions
/// </value>
public int? MaxResults { get; set; }

/// <summary>
/// Gets or sets a value indicating whether sensitive data such as user ids and user messages may appear in logs.
/// </summary>
/// <value>Defaults to <see langword="false"/>.</value>
public bool EnableSensitiveTelemetryData { get; set; }

/// <summary>
/// Behavior choices for the provider.
/// </summary>
Expand Down
105 changes: 104 additions & 1 deletion dotnet/tests/Microsoft.Agents.AI.Mem0.UnitTests/Mem0ProviderTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ public async Task InvokingAsync_PerformsSearch_AndReturnsContextMessageAsync()
ThreadId = "thread",
UserId = "user"
};
var sut = new Mem0Provider(this._httpClient, storageScope, loggerFactory: this._loggerFactoryMock.Object);
var sut = new Mem0Provider(this._httpClient, storageScope, options: new() { EnableSensitiveTelemetryData = true }, loggerFactory: this._loggerFactoryMock.Object);
var invokingContext = new AIContextProvider.InvokingContext(new[] { new ChatMessage(ChatRole.User, "What is my name?") });

// Act
Expand Down Expand Up @@ -130,6 +130,60 @@ public async Task InvokingAsync_PerformsSearch_AndReturnsContextMessageAsync()
Times.Once);
}

[Theory]
[InlineData(false, false, 2)]
[InlineData(true, false, 2)]
[InlineData(false, true, 1)]
[InlineData(true, true, 1)]
public async Task InvokingAsync_LogsUserIdBasedOnEnableSensitiveTelemetryDataAsync(bool enableSensitiveTelemetryData, bool requestThrows, int expectedLogInvocations)
{
// Arrange
if (requestThrows)
{
this._handler.EnqueueEmptyInternalServerError();
}
else
{
this._handler.EnqueueJsonResponse("[ { \"id\": \"1\", \"memory\": \"Name is Caoimhe\", \"hash\": \"h\", \"metadata\": null, \"score\": 0.9, \"created_at\": \"2023-01-01T00:00:00Z\", \"updated_at\": null, \"user_id\": \"u\", \"app_id\": null, \"agent_id\": \"agent\", \"session_id\": \"thread\" } ]");
}

var storageScope = new Mem0ProviderScope
{
ApplicationId = "app",
AgentId = "agent",
ThreadId = "thread",
UserId = "user"
};
var options = new Mem0ProviderOptions { EnableSensitiveTelemetryData = enableSensitiveTelemetryData };

var sut = new Mem0Provider(this._httpClient, storageScope, options: options, loggerFactory: this._loggerFactoryMock.Object);
var invokingContext = new AIContextProvider.InvokingContext(new[] { new ChatMessage(ChatRole.User, "Who am I?") });

// Act
await sut.InvokingAsync(invokingContext, CancellationToken.None);

// Assert
Assert.Equal(expectedLogInvocations, this._loggerMock.Invocations.Count);
foreach (var logInvocation in this._loggerMock.Invocations)
{
var state = Assert.IsAssignableFrom<IReadOnlyList<KeyValuePair<string, object?>>>(logInvocation.Arguments[2]);
var userIdValue = state.First(kvp => kvp.Key == "UserId").Value;
Assert.Equal(enableSensitiveTelemetryData ? "user" : "<redacted>", userIdValue);

var inputValue = state.FirstOrDefault(kvp => kvp.Key == "Input").Value;
if (inputValue != null)
{
Assert.Equal(enableSensitiveTelemetryData ? "Who am I?" : "<redacted>", inputValue);
}

var messageTextValue = state.FirstOrDefault(kvp => kvp.Key == "MessageText").Value;
if (messageTextValue != null)
{
Assert.Equal(enableSensitiveTelemetryData ? "## Memories\nConsider the following memories when answering user questions:\nName is Caoimhe" : "<redacted>", messageTextValue);
}
}
}

[Fact]
public async Task InvokedAsync_PersistsAllowedMessagesAsync()
{
Expand Down Expand Up @@ -218,6 +272,55 @@ public async Task InvokedAsync_ShouldNotThrow_WhenStorageFailsAsync()
Times.Once);
}

[Theory]
[InlineData(false, false, 0)]
[InlineData(true, false, 0)]
[InlineData(false, true, 1)]
[InlineData(true, true, 1)]
public async Task InvokedAsync_LogsUserIdBasedOnEnableSensitiveTelemetryDataAsync(bool enableSensitiveTelemetryData, bool requestThrows, int expectedLogCount)
{
// Arrange
if (requestThrows)
{
this._handler.EnqueueEmptyInternalServerError();
}
else
{
this._handler.EnqueueJsonResponse("[ { \"id\": \"1\", \"memory\": \"Name is Caoimhe\", \"hash\": \"h\", \"metadata\": null, \"score\": 0.9, \"created_at\": \"2023-01-01T00:00:00Z\", \"updated_at\": null, \"user_id\": \"u\", \"app_id\": null, \"agent_id\": \"agent\", \"session_id\": \"thread\" } ]");
}

var storageScope = new Mem0ProviderScope
{
ApplicationId = "app",
AgentId = "agent",
ThreadId = "thread",
UserId = "user"
};

var options = new Mem0ProviderOptions { EnableSensitiveTelemetryData = enableSensitiveTelemetryData };
var sut = new Mem0Provider(this._httpClient, storageScope, options: options, loggerFactory: this._loggerFactoryMock.Object);
var requestMessages = new List<ChatMessage>
{
new(ChatRole.User, "User text")
};
var responseMessages = new List<ChatMessage>
{
new(ChatRole.Assistant, "Assistant text")
};

// Act
await sut.InvokedAsync(new AIContextProvider.InvokedContext(requestMessages, aiContextProviderMessages: null) { ResponseMessages = responseMessages });

// Assert
Assert.Equal(expectedLogCount, this._loggerMock.Invocations.Count);
foreach (var logInvocation in this._loggerMock.Invocations)
{
var state = Assert.IsAssignableFrom<IReadOnlyList<KeyValuePair<string, object?>>>(logInvocation.Arguments[2]);
var userIdValue = state.First(kvp => kvp.Key == "UserId").Value;
Assert.Equal(enableSensitiveTelemetryData ? "user" : "<redacted>", userIdValue);
}
}

[Fact]
public async Task ClearStoredMemoriesAsync_SendsDeleteWithQueryAsync()
{
Expand Down
Loading
Loading