Skip to content

Commit

Permalink
Fixed bug when nullable enabled, and dbset declared as nullable.
Browse files Browse the repository at this point in the history
In DbContext that does not make, sense, so use !
  • Loading branch information
kant2002 committed Aug 19, 2024
1 parent 20ed18e commit 544287f
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 5 deletions.
3 changes: 3 additions & 0 deletions SqlMarshal.CompilationTests/DbContextManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,7 @@ public DbContextManager(PersonDbContext context)

[SqlMarshal("persons_by_id")]
public partial PersonDbContext.Person GetPersonById(int personId);

[SqlMarshal("users_list")]
public partial IList<PersonDbContext.User> GetUsers();
}
11 changes: 11 additions & 0 deletions SqlMarshal.CompilationTests/PersonDbContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ public PersonDbContext(DbContextOptions<PersonDbContext> options)

public DbSet<Person> Persons { get; set; } = null!;

public DbSet<User> Users { get; set; } = null!;

internal class Person
{
[Column("person_id")]
Expand All @@ -26,5 +28,14 @@ internal class Person
[Column("person_name")]
public string? PersonName { get; set; }
}

internal class User
{
[Column("user_id")]
public int UserId { get; set; }

[Column("user_name")]
public string? UserName { get; set; }
}
}
}
12 changes: 12 additions & 0 deletions SqlMarshal.CompilationTests/sqlmarshal_sample.sql
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,15 @@ begin
end
GO

CREATE TABLE [user] (
user_id int not null identity primary key,
user_name nvarchar(100) null
)
GO



CREATE OR ALTER PROCEDURE users_list
AS
SELECT * from [user]
GO
70 changes: 70 additions & 0 deletions SqlMarshal.Tests/StoredProcedureGenerationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,76 @@ public partial IList<PersonItem> M(string clientId, string personId)
Assert.AreEqual(expectedOutput, output);
}

[TestMethod]
public void DbSetNameFoundFromClass2_WithNullable()
{
string source = @"
namespace Foo
{
public partial class CustomDbContext : DbContext
{
public virtual DbSet<Item>? Items { get; set; } = null!;
public virtual DbSet<PersonItem>? Persons { get; set; } = null!;
}
class C
{
private readonly CustomDbContext context;
[SqlMarshal(""sp_TestSP"")]
public partial IList<PersonItem> M(string? clientId, string personId)
}
}";
string output = this.GetGeneratedOutput(source, NullableContextOptions.Enable);

Assert.IsNotNull(output);

var expectedOutput = @"// <auto-generated>
// Code generated by Stored Procedures Code Generator.
// Changes may cause incorrect behavior and will be lost if the code is
// regenerated.
// </auto-generated>
#nullable enable
#pragma warning disable 1591
namespace Foo
{
using System;
using System.Data.Common;
using System.Linq;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Storage;
partial class C
{
public partial IList<PersonItem> M(string? clientId, string personId)
{
var connection = this.context.Database.GetDbConnection();
using var command = connection.CreateCommand();
var clientIdParameter = command.CreateParameter();
clientIdParameter.ParameterName = ""@client_id"";
clientIdParameter.Value = clientId == null ? (object)DBNull.Value : clientId;
var personIdParameter = command.CreateParameter();
personIdParameter.ParameterName = ""@person_id"";
personIdParameter.Value = personId;
var parameters = new DbParameter[]
{
clientIdParameter,
personIdParameter,
};
var sqlQuery = @""sp_TestSP @client_id, @person_id"";
var result = this.context.Persons!.FromSqlRaw(sqlQuery, parameters).ToList();
return result;
}
}
}";
Assert.AreEqual(expectedOutput, output);
}

[TestMethod]
public void NonReferenceParameterPassedDirectlyToStoredProcedure()
{
Expand Down
12 changes: 7 additions & 5 deletions SqlMarshal/Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ private static string GetAccessibility(Accessibility a)
};
}

private static ISymbol? GetDbSetField(IFieldSymbol? dbContextSymbol, ITypeSymbol itemTypeSymbol)
private static IPropertySymbol? GetDbSetField(IFieldSymbol? dbContextSymbol, ITypeSymbol itemTypeSymbol)
{
if (dbContextSymbol == null)
{
Expand Down Expand Up @@ -695,17 +695,19 @@ private void MapResults(
{
var dbContextSymbol = methodGenerationContext.ClassGenerationContext.DbContextField;
var contextName = methodGenerationContext.ClassGenerationContext.DbContextName;
var itemTypeProperty = GetDbSetField(dbContextSymbol, itemType)?.Name ?? itemType.Name + "s";
var dbsetField = GetDbSetField(dbContextSymbol, itemType);
var itemTypeProperty = dbsetField?.Name ?? itemType.Name + "s";
var nullableAnnotations = dbsetField?.NullableAnnotation == NullableAnnotation.Annotated && methodGenerationContext.ClassGenerationContext.NullableContextOptions.AnnotationsEnabled() ? "!" : string.Empty;
if (isTask)
{
if (isList)
{
source.AppendLine($"var result = await this.{contextName}.{itemTypeProperty}.FromSqlRaw(sqlQuery{(parameters.Length == 0 ? string.Empty : ", parameters")}).ToListAsync({cancellationToken}).ConfigureAwait(false);");
source.AppendLine($"var result = await this.{contextName}.{itemTypeProperty}{nullableAnnotations}.FromSqlRaw(sqlQuery{(parameters.Length == 0 ? string.Empty : ", parameters")}).ToListAsync({cancellationToken}).ConfigureAwait(false);");
}
else
{
source.AppendLine($"{itemType} result = null!;");
source.AppendLine($"var asyncEnumerable = this.{contextName}.{itemTypeProperty}.FromSqlRaw(sqlQuery{(parameters.Length == 0 ? string.Empty : ", parameters")}).AsAsyncEnumerable();");
source.AppendLine($"var asyncEnumerable = this.{contextName}.{itemTypeProperty}{nullableAnnotations}.FromSqlRaw(sqlQuery{(parameters.Length == 0 ? string.Empty : ", parameters")}).AsAsyncEnumerable();");
source.AppendLine($"await foreach (var current in asyncEnumerable)");
source.AppendLine("{");
source.PushIndent();
Expand All @@ -721,7 +723,7 @@ private void MapResults(
string materializeResults = isList
? "ToList"
: methodGenerationContext.ClassGenerationContext.NullableContextOptions == NullableContextOptions.Enable ? "AsEnumerable().First" : "AsEnumerable().FirstOrDefault";
source.AppendLine($"var result = this.{contextName}.{itemTypeProperty}.FromSqlRaw(sqlQuery{parameterString}).{materializeResults}();");
source.AppendLine($"var result = this.{contextName}.{itemTypeProperty}{nullableAnnotations}.FromSqlRaw(sqlQuery{parameterString}).{materializeResults}();");
}
}
}
Expand Down

0 comments on commit 544287f

Please sign in to comment.