diff --git a/src/Marten/Internal/ValueTypeIdentifiedDocumentStorage.cs b/src/Marten/Internal/ValueTypeIdentifiedDocumentStorage.cs index 432b2c829b..708de4b0c9 100644 --- a/src/Marten/Internal/ValueTypeIdentifiedDocumentStorage.cs +++ b/src/Marten/Internal/ValueTypeIdentifiedDocumentStorage.cs @@ -39,7 +39,12 @@ public void SetIdentity(TDoc document, TSimple identity) public IIdentitySetter Inner { get; } } -internal class ValueTypeIdentifiedDocumentStorage: IDocumentStorage where TDoc : notnull where TSimple : notnull where TValueType : notnull +internal interface IValueTypeStorage +{ + IQueryHandler> BuildLoadManyHandler(TValueType[] keys); +} + +internal class ValueTypeIdentifiedDocumentStorage: IDocumentStorage, IValueTypeStorage where TDoc : notnull where TSimple : notnull where TValueType : notnull { private readonly Func _converter; private readonly Func _unwrapper; @@ -52,6 +57,12 @@ public ValueTypeIdentifiedDocumentStorage(ValueTypeInfo valueTypeInfo, IDocument _unwrapper = valueTypeInfo.UnWrapper(); } + public IQueryHandler> BuildLoadManyHandler(TValueType[] keys) + { + var ids = keys.Select(x => _unwrapper(x)).ToArray(); + return new LoadByIdArrayHandler(Inner, ids); + } + public IDocumentStorage Inner { get; } public void Apply(ICommandBuilder builder) => Inner.Apply(builder); diff --git a/src/Marten/Linq/QueryHandlers/LoadByIdHandler.cs b/src/Marten/Linq/QueryHandlers/LoadByIdHandler.cs index bbd9b45276..270b211614 100644 --- a/src/Marten/Linq/QueryHandlers/LoadByIdHandler.cs +++ b/src/Marten/Linq/QueryHandlers/LoadByIdHandler.cs @@ -1,6 +1,8 @@ #nullable enable +using System; using System.Data.Common; using System.IO; +using System.Linq; using System.Threading; using System.Threading.Tasks; using JasperFx.Core.Reflection; @@ -17,10 +19,12 @@ internal class LoadByIdHandler: IQueryHandler where T : notnull where { private readonly TId _id; private readonly IDocumentStorage storage; + private static readonly Type[] _identityTypes = [typeof(int), typeof(long), typeof(string), typeof(Guid)]; public LoadByIdHandler(IDocumentStorage documentStorage, TId id) { storage = documentStorage; + _id = id; } @@ -40,12 +44,24 @@ public void ConfigureCommand(ICommandBuilder sql, IMartenSession session) sql.Append(storage.FromObject); sql.Append(" as d where id = "); - sql.AppendParameter(_id); + if (_identityTypes.Contains(typeof(TId))) + { + sql.AppendParameter(_id); + } + else + { + var valueType = ValueTypeInfo.ForType(typeof(TId)); + typeof(Appender<,>).CloseAndBuildAs>(valueType, typeof(TId), valueType.SimpleType) + .Append(sql, _id); + } + storage.AddTenancyFilter(sql, session.TenantId); } + + public T Handle(DbDataReader reader, IMartenSession session) { var selector = (ISelector)storage.BuildSelector(session); @@ -68,3 +84,24 @@ public Task StreamJson(Stream stream, DbDataReader reader, CancellationToke return reader.As().StreamOne(stream, token); } } + +internal interface IAppender +{ + public void Append(ICommandBuilder builder, TId id); +} + +internal class Appender: IAppender +{ + private readonly ValueTypeInfo _valueType; + + public Appender(ValueTypeInfo valueType) + { + _valueType = valueType; + } + + public void Append(ICommandBuilder builder, TId id) + { + var simple = _valueType.UnWrapper()(id); + builder.AppendParameter(simple); + } +} diff --git a/src/Marten/Services/BatchQuerying/BatchedQuery.cs b/src/Marten/Services/BatchQuerying/BatchedQuery.cs index 6d6ef970ba..76b047cce1 100644 --- a/src/Marten/Services/BatchQuerying/BatchedQuery.cs +++ b/src/Marten/Services/BatchQuerying/BatchedQuery.cs @@ -1,14 +1,11 @@ -#nullable enable using System; using System.Collections.Generic; using System.Linq; using System.Threading; using System.Threading.Tasks; using JasperFx.Core.Reflection; -using JasperFx.Events; -using Marten.Events; -using Marten.Events.Querying; using Marten.Exceptions; +using Marten.Internal; using Marten.Internal.Sessions; using Marten.Internal.Storage; using Marten.Linq; @@ -32,6 +29,12 @@ public BatchedQuery(QuerySession parent) public IBatchEvents Events => this; + public Task Load(object id) where T : class + { + var loader = typeof(Loader<>).CloseAndBuildAs(id.GetType()); + return loader.Load(id, this); + } + public Task Load(string id) where T : class { return load(id); @@ -208,8 +211,22 @@ public Task Average(IQueryable queryable) where T : notnull return addItem(queryable, SingleValueMode.Average); } + private interface ILoader + { + Task Load(object id, BatchedQuery parent) where T : class; + } + + private class Loader: ILoader + { + public Task Load(object id, BatchedQuery parent) where T : class + { + return parent.load((TId)id); + } + } + internal class BatchLoadByKeys: IBatchLoadByKeys where TDoc : class { + private static readonly Type[] _identityTypes = [typeof(int), typeof(long), typeof(Guid), typeof(string)]; private readonly BatchedQuery _parent; public BatchLoadByKeys(BatchedQuery parent) @@ -219,6 +236,10 @@ public BatchLoadByKeys(BatchedQuery parent) public Task> ById(params TKey[] keys) { + if (typeof(TKey).IsNullable()) + throw new ArgumentOutOfRangeException(nameof(TKey), + "Cannot use nullable types as the TKey, you may need to explicitly define the generic argument"); + return load(keys); } @@ -229,8 +250,14 @@ public Task> ByIdList(IEnumerable keys) private Task> load(TKey[] keys) { - var storage = _parent.Parent.StorageFor(); - return _parent.AddItem(new LoadByIdArrayHandler(storage, keys)); + var storage = _parent.Parent.StorageFor(); + if (_identityTypes.Contains(typeof(TKey))) + { + return _parent.AddItem(new LoadByIdArrayHandler(storage, keys)); + } + + throw new ArgumentOutOfRangeException(nameof(keys), + "Marten cannot (yet) handle this identity type for this operation"); } } } diff --git a/src/Marten/Services/BatchQuerying/IBatchedQuery.cs b/src/Marten/Services/BatchQuerying/IBatchedQuery.cs index db6e1b28c5..75ef3e9162 100644 --- a/src/Marten/Services/BatchQuerying/IBatchedQuery.cs +++ b/src/Marten/Services/BatchQuerying/IBatchedQuery.cs @@ -177,6 +177,14 @@ public interface IBatchedQuery /// Task Load(Guid id) where T : class; + /// + /// Load a single document of Type "T" by id + /// + /// + /// + /// + Task Load(object id) where T : class; + /// /// Load a one or more documents of Type "T" by id's /// diff --git a/src/ValueTypeTests/using_in_batch_queries.cs b/src/ValueTypeTests/using_in_batch_queries.cs new file mode 100644 index 0000000000..5cb1ffc376 --- /dev/null +++ b/src/ValueTypeTests/using_in_batch_queries.cs @@ -0,0 +1,35 @@ +using System.Threading.Tasks; +using Marten.Services.BatchQuerying; +using Marten.Testing.Harness; +using Shouldly; + +namespace ValueTypeTests; + +public class using_in_batch_queries : OneOffConfigurationsContext +{ + [Fact] + public async Task load_one_at_a_time() + { + var teacher1 = new Teacher(); + var teacher2 = new Teacher(); + var teacher3 = new Teacher(); + var teacher4 = new Teacher(); + + theSession.Store(teacher1, teacher2, teacher3, teacher4); + await theSession.SaveChangesAsync(); + + var loaded = await theSession.LoadAsync(teacher4.Id); + loaded.ShouldNotBeNull(); + + var batch = theSession.CreateBatchQuery(); + var teacher1_task = batch.Load(teacher1.Id); + var teacher2_task = batch.Load(teacher2.Id); + var teacher3_task = batch.Load(teacher3.Id); + + await batch.Execute(); + + (await teacher1_task).ShouldNotBeNull(); + (await teacher2_task).ShouldNotBeNull(); + (await teacher3_task).ShouldNotBeNull(); + } +}