Skip to content

Commit

Permalink
Add support for IEnumerable in the results
Browse files Browse the repository at this point in the history
  • Loading branch information
kant2002 committed Aug 19, 2024
1 parent d2fb57c commit a7b052d
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 7 deletions.
3 changes: 3 additions & 0 deletions SqlMarshal.CompilationTests/DbContextManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ public DbContextManager(PersonDbContext context)
[SqlMarshal("persons_list")]
public partial IList<(int Id, string Name)> GetTupleResult();

[SqlMarshal("persons_list")]
public partial IEnumerable<PersonDbContext.Person> GetEnumerableResult();

[SqlMarshal("persons_list")]
public partial Task<IList<PersonDbContext.Person>> GetResultAsync();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>net6.0</TargetFramework>
<TargetFrameworks>net6.0;net8.0</TargetFrameworks>
<TrimMode>link</TrimMode>
<Nullable>enable</Nullable>
<EmitCompilerGeneratedFiles>true</EmitCompilerGeneratedFiles>
Expand Down
2 changes: 1 addition & 1 deletion SqlMarshal.Tests/SqlMarshal.Tests.csproj
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk">

<PropertyGroup>
<TargetFramework>net6.0</TargetFramework>
<TargetFrameworks>net6.0;net8.0</TargetFrameworks>

<IsPackable>false</IsPackable>
<NoWarn>CS1591;SA1600</NoWarn>
Expand Down
50 changes: 49 additions & 1 deletion SqlMarshal.Tests/StoredProcedureGenerationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace SqlMarshal.Tests;
public class StoredProcedureGenerationTests : CodeGenerationTestBase
{
[TestMethod]
public void MapResultSetToProcedure()
public void MapResultSetFromIListToProcedure()
{
string source = @"
namespace Foo
Expand Down Expand Up @@ -60,6 +60,54 @@ public partial IList<Item> M()
Assert.AreEqual(expectedOutput, output);
}

[TestMethod]
public void MapResultSetFromIEnumerableToProcedure()
{
string source = @"
namespace Foo
{
class C
{
[SqlMarshal(""sp_TestSP"")]
public partial IEnumerable<Item> M()
}
}";
string output = this.GetGeneratedOutput(source, NullableContextOptions.Disable);

Assert.IsNotNull(output);

var expectedOutput = @"// <auto-generated>
// Code generated by Stored Procedures Code Generator.
// Changes may cause incorrect behavior and will be lost if the code is
// regenerated.
// </auto-generated>
#nullable enable
#pragma warning disable 1591
namespace Foo
{
using System;
using System.Data.Common;
using System.Linq;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Storage;
partial class C
{
public partial IEnumerable<Item> M()
{
var connection = this.dbContext.Database.GetDbConnection();
using var command = connection.CreateCommand();
var sqlQuery = @""sp_TestSP"";
var result = this.dbContext.Items.FromSqlRaw(sqlQuery).ToList();
return result;
}
}
}";
Assert.AreEqual(expectedOutput, output);
}

[TestMethod]
public void MapSingleObjectToProcedure()
{
Expand Down
2 changes: 1 addition & 1 deletion SqlMarshal/ClassGenerationContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public ClassGenerationContext(

public bool HasEfCore => this.ConnectionField == null && this.Methods.All(_ => _.ConnectionParameter == null);

public bool HasCollections => !this.HasEfCore || this.Methods.Any(_ => _.IsList && (IsScalarType(_.ItemType) || IsTuple(_.ItemType)));
public bool HasCollections => !this.HasEfCore || this.Methods.Any(_ => (_.IsList || _.IsEnumerable) && (IsScalarType(_.ItemType) || IsTuple(_.ItemType)));

private static IFieldSymbol? GetConnectionField(INamedTypeSymbol classSymbol)
{
Expand Down
6 changes: 5 additions & 1 deletion SqlMarshal/Extensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,15 @@ internal static ITypeSymbol GetUnderlyingType(ITypeSymbol returnType)
return returnType;
}

internal static bool IsList(ITypeSymbol returnType) => returnType.Name == "IList" || returnType.Name == "List";

internal static bool IsEnumerable(ITypeSymbol returnType) => returnType.Name == "IEnumerable";

internal static ITypeSymbol UnwrapListItem(ITypeSymbol returnType)
{
if (returnType is INamedTypeSymbol namedTypeSymbol)
{
if (returnType.Name != "IList" && returnType.Name != "List")
if (!IsList(returnType) && !IsEnumerable(returnType))
{
return returnType;
}
Expand Down
2 changes: 1 addition & 1 deletion SqlMarshal/Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,7 @@ private void ProcessMethod(
var signature = $"({string.Join(", ", originalParameters.Select((parameterSymbol, index) => GetParameterDeclaration(methodSymbol, parameterSymbol, index)))})";
var itemType = methodGenerationContext.ItemType;
var getConnection = this.GetConnectionStatement(methodGenerationContext);
var isList = methodGenerationContext.IsList;
var isList = methodGenerationContext.IsList || methodGenerationContext.IsEnumerable;
var isScalarType = IsScalarType(UnwrapNullableType(returnType))
|| returnType.SpecialType == SpecialType.System_Void
|| returnType.Name == "Task";
Expand Down
4 changes: 3 additions & 1 deletion SqlMarshal/MethodGenerationContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ internal MethodGenerationContext(ClassGenerationContext classGenerationContext,

internal ITypeSymbol ReturnType => this.MethodSymbol.ReturnType.UnwrapTaskType();

internal bool IsList => this.ItemType != this.ReturnType;
internal bool IsList => IsList(this.ReturnType);

internal bool IsEnumerable => IsEnumerable(this.ReturnType);

internal ITypeSymbol ItemType => UnwrapListItem(this.ReturnType);

Expand Down

0 comments on commit a7b052d

Please sign in to comment.