From a7b052d4a2abe3706ed369a10e8c47c17d02e97b Mon Sep 17 00:00:00 2001 From: Andrii Kurdiumov Date: Mon, 19 Aug 2024 13:16:44 +0500 Subject: [PATCH] Add support for IEnumerable in the results --- .../DbContextManager.cs | 3 ++ .../SqlMarshal.CompilationTests.csproj | 2 +- SqlMarshal.Tests/SqlMarshal.Tests.csproj | 2 +- .../StoredProcedureGenerationTests.cs | 50 ++++++++++++++++++- SqlMarshal/ClassGenerationContext.cs | 2 +- SqlMarshal/Extensions.cs | 6 ++- SqlMarshal/Generator.cs | 2 +- SqlMarshal/MethodGenerationContext.cs | 4 +- 8 files changed, 64 insertions(+), 7 deletions(-) diff --git a/SqlMarshal.CompilationTests/DbContextManager.cs b/SqlMarshal.CompilationTests/DbContextManager.cs index cb39fde..da4d3e0 100644 --- a/SqlMarshal.CompilationTests/DbContextManager.cs +++ b/SqlMarshal.CompilationTests/DbContextManager.cs @@ -24,6 +24,9 @@ public DbContextManager(PersonDbContext context) [SqlMarshal("persons_list")] public partial IList<(int Id, string Name)> GetTupleResult(); + [SqlMarshal("persons_list")] + public partial IEnumerable GetEnumerableResult(); + [SqlMarshal("persons_list")] public partial Task> GetResultAsync(); diff --git a/SqlMarshal.CompilationTests/SqlMarshal.CompilationTests.csproj b/SqlMarshal.CompilationTests/SqlMarshal.CompilationTests.csproj index 635a90e..9c2595d 100644 --- a/SqlMarshal.CompilationTests/SqlMarshal.CompilationTests.csproj +++ b/SqlMarshal.CompilationTests/SqlMarshal.CompilationTests.csproj @@ -2,7 +2,7 @@ Exe - net6.0 + net6.0;net8.0 link enable true diff --git a/SqlMarshal.Tests/SqlMarshal.Tests.csproj b/SqlMarshal.Tests/SqlMarshal.Tests.csproj index 030fbea..adc74b9 100644 --- a/SqlMarshal.Tests/SqlMarshal.Tests.csproj +++ b/SqlMarshal.Tests/SqlMarshal.Tests.csproj @@ -1,7 +1,7 @@  - net6.0 + net6.0;net8.0 false CS1591;SA1600 diff --git a/SqlMarshal.Tests/StoredProcedureGenerationTests.cs b/SqlMarshal.Tests/StoredProcedureGenerationTests.cs index df8140d..c9e3103 100644 --- a/SqlMarshal.Tests/StoredProcedureGenerationTests.cs +++ b/SqlMarshal.Tests/StoredProcedureGenerationTests.cs @@ -13,7 +13,7 @@ namespace SqlMarshal.Tests; public class StoredProcedureGenerationTests : CodeGenerationTestBase { [TestMethod] - public void MapResultSetToProcedure() + public void MapResultSetFromIListToProcedure() { string source = @" namespace Foo @@ -60,6 +60,54 @@ public partial IList M() Assert.AreEqual(expectedOutput, output); } + [TestMethod] + public void MapResultSetFromIEnumerableToProcedure() + { + string source = @" +namespace Foo +{ + class C + { + [SqlMarshal(""sp_TestSP"")] + public partial IEnumerable M() + } +}"; + string output = this.GetGeneratedOutput(source, NullableContextOptions.Disable); + + Assert.IsNotNull(output); + + var expectedOutput = @"// +// Code generated by Stored Procedures Code Generator. +// Changes may cause incorrect behavior and will be lost if the code is +// regenerated. +// +#nullable enable +#pragma warning disable 1591 + +namespace Foo +{ + using System; + using System.Data.Common; + using System.Linq; + using Microsoft.EntityFrameworkCore; + using Microsoft.EntityFrameworkCore.Storage; + + partial class C + { + public partial IEnumerable M() + { + var connection = this.dbContext.Database.GetDbConnection(); + using var command = connection.CreateCommand(); + + var sqlQuery = @""sp_TestSP""; + var result = this.dbContext.Items.FromSqlRaw(sqlQuery).ToList(); + return result; + } + } +}"; + Assert.AreEqual(expectedOutput, output); + } + [TestMethod] public void MapSingleObjectToProcedure() { diff --git a/SqlMarshal/ClassGenerationContext.cs b/SqlMarshal/ClassGenerationContext.cs index 818d7d5..195f0c2 100644 --- a/SqlMarshal/ClassGenerationContext.cs +++ b/SqlMarshal/ClassGenerationContext.cs @@ -46,7 +46,7 @@ public ClassGenerationContext( public bool HasEfCore => this.ConnectionField == null && this.Methods.All(_ => _.ConnectionParameter == null); - public bool HasCollections => !this.HasEfCore || this.Methods.Any(_ => _.IsList && (IsScalarType(_.ItemType) || IsTuple(_.ItemType))); + public bool HasCollections => !this.HasEfCore || this.Methods.Any(_ => (_.IsList || _.IsEnumerable) && (IsScalarType(_.ItemType) || IsTuple(_.ItemType))); private static IFieldSymbol? GetConnectionField(INamedTypeSymbol classSymbol) { diff --git a/SqlMarshal/Extensions.cs b/SqlMarshal/Extensions.cs index 93562cb..cebf13d 100644 --- a/SqlMarshal/Extensions.cs +++ b/SqlMarshal/Extensions.cs @@ -151,11 +151,15 @@ internal static ITypeSymbol GetUnderlyingType(ITypeSymbol returnType) return returnType; } + internal static bool IsList(ITypeSymbol returnType) => returnType.Name == "IList" || returnType.Name == "List"; + + internal static bool IsEnumerable(ITypeSymbol returnType) => returnType.Name == "IEnumerable"; + internal static ITypeSymbol UnwrapListItem(ITypeSymbol returnType) { if (returnType is INamedTypeSymbol namedTypeSymbol) { - if (returnType.Name != "IList" && returnType.Name != "List") + if (!IsList(returnType) && !IsEnumerable(returnType)) { return returnType; } diff --git a/SqlMarshal/Generator.cs b/SqlMarshal/Generator.cs index 89d1d7d..25fc021 100644 --- a/SqlMarshal/Generator.cs +++ b/SqlMarshal/Generator.cs @@ -751,7 +751,7 @@ private void ProcessMethod( var signature = $"({string.Join(", ", originalParameters.Select((parameterSymbol, index) => GetParameterDeclaration(methodSymbol, parameterSymbol, index)))})"; var itemType = methodGenerationContext.ItemType; var getConnection = this.GetConnectionStatement(methodGenerationContext); - var isList = methodGenerationContext.IsList; + var isList = methodGenerationContext.IsList || methodGenerationContext.IsEnumerable; var isScalarType = IsScalarType(UnwrapNullableType(returnType)) || returnType.SpecialType == SpecialType.System_Void || returnType.Name == "Task"; diff --git a/SqlMarshal/MethodGenerationContext.cs b/SqlMarshal/MethodGenerationContext.cs index fed11b1..8cd2f0a 100644 --- a/SqlMarshal/MethodGenerationContext.cs +++ b/SqlMarshal/MethodGenerationContext.cs @@ -76,7 +76,9 @@ internal MethodGenerationContext(ClassGenerationContext classGenerationContext, internal ITypeSymbol ReturnType => this.MethodSymbol.ReturnType.UnwrapTaskType(); - internal bool IsList => this.ItemType != this.ReturnType; + internal bool IsList => IsList(this.ReturnType); + + internal bool IsEnumerable => IsEnumerable(this.ReturnType); internal ITypeSymbol ItemType => UnwrapListItem(this.ReturnType);