diff --git a/SqlMarshal.Tests/CrudGenerationTests.cs b/SqlMarshal.Tests/CrudGenerationTests.cs index 7b3aa74..7176516 100644 --- a/SqlMarshal.Tests/CrudGenerationTests.cs +++ b/SqlMarshal.Tests/CrudGenerationTests.cs @@ -82,6 +82,85 @@ partial class C Assert.AreEqual(expectedOutput, output); } + [TestMethod] + public void FindById_IList_Sync() + { + string source = @" +#nullable enable +namespace Foo +{ + class TestEntity + { + public int Id { get; set; } + + public string Name { get; set; } + } + + [Repository(typeof(TestEntity))] + partial class C + { + private DbConnection connection; + + public partial TestEntity? FindById(int id); + } +}"; + string output = this.GetGeneratedOutput(source, NullableContextOptions.Disable); + + Assert.IsNotNull(output); + + var expectedOutput = @"// +// Code generated by Stored Procedures Code Generator. +// Changes may cause incorrect behavior and will be lost if the code is +// regenerated. +// +#nullable enable +#pragma warning disable 1591 + +namespace Foo +{ + using System; + using System.Collections.Generic; + using System.Data.Common; + using System.Linq; + + partial class C + { + public partial Foo.TestEntity? FindById(int id) + { + var connection = this.connection; + using var command = connection.CreateCommand(); + + var idParameter = command.CreateParameter(); + idParameter.ParameterName = ""@id""; + idParameter.Value = id; + + var parameters = new DbParameter[] + { + idParameter, + }; + + var sqlQuery = @""SELECT Id, Name FROM TestEntity WHERE Id = @id""; + command.CommandText = sqlQuery; + command.Parameters.AddRange(parameters); + using var reader = command.ExecuteReader(System.Data.CommandBehavior.SingleResult | System.Data.CommandBehavior.SingleRow); + if (!reader.Read()) + { + return null; + } + + var result = new TestEntity(); + var value_0 = reader.GetValue(0); + result.Id = (int)value_0; + var value_1 = reader.GetValue(1); + result.Name = value_1 == DBNull.Value ? (string?)null : (string)value_1; + reader.Close(); + return result; + } + } +}"; + Assert.AreEqual(expectedOutput, output); + } + [TestMethod] public void Count_int_Sync() { @@ -136,6 +215,130 @@ public partial int Count() return (int)result!; } } +}"; + Assert.AreEqual(expectedOutput, output); + } + + [TestMethod] + public void DeleteAll_IList_Sync() + { + string source = @" +#nullable enable +namespace Foo +{ + class TestEntity + { + public int Id { get; set; } + + public string Name { get; set; } + } + + [Repository(typeof(TestEntity))] + partial class C + { + private DbConnection connection; + + public partial void DeleteAll(); + } +}"; + string output = this.GetGeneratedOutput(source, NullableContextOptions.Disable); + + Assert.IsNotNull(output); + + var expectedOutput = @"// +// Code generated by Stored Procedures Code Generator. +// Changes may cause incorrect behavior and will be lost if the code is +// regenerated. +// +#nullable enable +#pragma warning disable 1591 + +namespace Foo +{ + using System; + using System.Collections.Generic; + using System.Data.Common; + using System.Linq; + + partial class C + { + public partial void DeleteAll() + { + var connection = this.connection; + using var command = connection.CreateCommand(); + + var sqlQuery = @""DELETE FROM TestEntity""; + command.CommandText = sqlQuery; + command.ExecuteNonQuery(); + } + } +}"; + Assert.AreEqual(expectedOutput, output); + } + + [TestMethod] + public void DeleteById_IList_Sync() + { + string source = @" +#nullable enable +namespace Foo +{ + class TestEntity + { + public int Id { get; set; } + + public string Name { get; set; } + } + + [Repository(typeof(TestEntity))] + partial class C + { + private DbConnection connection; + + public partial void DeleteById(int id); + } +}"; + string output = this.GetGeneratedOutput(source, NullableContextOptions.Disable); + + Assert.IsNotNull(output); + + var expectedOutput = @"// +// Code generated by Stored Procedures Code Generator. +// Changes may cause incorrect behavior and will be lost if the code is +// regenerated. +// +#nullable enable +#pragma warning disable 1591 + +namespace Foo +{ + using System; + using System.Collections.Generic; + using System.Data.Common; + using System.Linq; + + partial class C + { + public partial void DeleteById(int id) + { + var connection = this.connection; + using var command = connection.CreateCommand(); + + var idParameter = command.CreateParameter(); + idParameter.ParameterName = ""@id""; + idParameter.Value = id; + + var parameters = new DbParameter[] + { + idParameter, + }; + + var sqlQuery = @""DELETE FROM TestEntity WHERE Id = @id""; + command.CommandText = sqlQuery; + command.Parameters.AddRange(parameters); + command.ExecuteNonQuery(); + } + } }"; Assert.AreEqual(expectedOutput, output); } diff --git a/SqlMarshal/ClassGenerationContext.cs b/SqlMarshal/ClassGenerationContext.cs index 748d3a3..c4d3601 100644 --- a/SqlMarshal/ClassGenerationContext.cs +++ b/SqlMarshal/ClassGenerationContext.cs @@ -18,13 +18,14 @@ public ClassGenerationContext( List methods, ISymbol attributeSymbol, ISymbol repositoryAttributeSymbol, - NullableContextOptions nullableContextOptions) + GeneratorExecutionContext context) { this.ClassSymbol = classSymbol; this.Methods = methods.Select(_ => new MethodGenerationContext(this, _)).ToList(); this.AttributeSymbol = attributeSymbol; this.RepositoryAttributeSymbol = repositoryAttributeSymbol; - this.NullableContextOptions = nullableContextOptions; + this.GeneratorExecutionContext = context; + this.NullableContextOptions = context.Compilation.Options.NullableContextOptions; this.ConnectionField = GetConnectionField(classSymbol); this.DbContextField = GetContextField(classSymbol); @@ -38,6 +39,8 @@ public ClassGenerationContext( public ISymbol RepositoryAttributeSymbol { get; } + public GeneratorExecutionContext GeneratorExecutionContext { get; } + public NullableContextOptions NullableContextOptions { get; } public bool HasNullableAnnotations => this.NullableContextOptions != NullableContextOptions.Disable; diff --git a/SqlMarshal/Extensions.cs b/SqlMarshal/Extensions.cs index cebf13d..34b1d4e 100644 --- a/SqlMarshal/Extensions.cs +++ b/SqlMarshal/Extensions.cs @@ -7,6 +7,8 @@ namespace SqlMarshal; using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using System.Linq; internal static class Extensions { @@ -174,4 +176,9 @@ internal static ITypeSymbol UnwrapListItem(ITypeSymbol returnType) return returnType; } + + internal static IPropertySymbol? FindIdMember(this ITypeSymbol returnType) + { + return returnType.GetMembers().OfType().FirstOrDefault(_ => _.Name == "Id"); + } } diff --git a/SqlMarshal/Generator.cs b/SqlMarshal/Generator.cs index f71f246..c041ea9 100644 --- a/SqlMarshal/Generator.cs +++ b/SqlMarshal/Generator.cs @@ -8,6 +8,7 @@ namespace SqlMarshal; using System.Collections.Generic; using System.Linq; +using System.Reflection.Metadata; using System.Text; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; @@ -59,6 +60,12 @@ public RepositoryAttribute(global::System.Type entityType) } "; + private static DiagnosticDescriptor SP0001 { get; } = new DiagnosticDescriptor("SP0001", "No stored procedure attribute", "Internal analyzer error.", "Internal", DiagnosticSeverity.Error, true); + + private static DiagnosticDescriptor SP0002 { get; } = new DiagnosticDescriptor("SP0002", "No repository attribute", "Internal analyzer error.", "Internal", DiagnosticSeverity.Error, true); + + private static DiagnosticDescriptor SP0003 { get; } = new DiagnosticDescriptor("SP0002", "Id property cannot be guessed", "Cannot find id property for entity type {0}.", "SqlMarshal", DiagnosticSeverity.Error, true); + /// public void Initialize(GeneratorInitializationContext context) { @@ -78,18 +85,14 @@ public void Execute(GeneratorExecutionContext context) INamedTypeSymbol? attributeSymbol = context.Compilation.GetTypeByMetadataName("SqlMarshalAttribute"); if (attributeSymbol == null) { - context.ReportDiagnostic(Diagnostic.Create( - new DiagnosticDescriptor("SP0001", "No stored procedure attribute", "Internal analyzer error.", "Internal", DiagnosticSeverity.Error, true), - null)); + context.ReportDiagnostic(Diagnostic.Create(SP0001, null)); return; } INamedTypeSymbol? repositoryAttributeSymbol = context.Compilation.GetTypeByMetadataName("RepositoryAttribute"); if (repositoryAttributeSymbol == null) { - context.ReportDiagnostic(Diagnostic.Create( - new DiagnosticDescriptor("SP0002", "No repository attribute", "Internal analyzer error.", "Internal", DiagnosticSeverity.Error, true), - null)); + context.ReportDiagnostic(Diagnostic.Create(SP0002, null)); return; } @@ -104,7 +107,7 @@ public void Execute(GeneratorExecutionContext context) group.ToList(), attributeSymbol, repositoryAttributeSymbol, - context.Compilation.Options.NullableContextOptions); + context); var sourceCode = this.ProcessClass( generationContext, (INamedTypeSymbol)group.Key!, @@ -112,9 +115,7 @@ public void Execute(GeneratorExecutionContext context) hasNullableAnnotations); if (sourceCode == null) { - context.ReportDiagnostic(Diagnostic.Create( - new DiagnosticDescriptor("SP0002", "No source code generated attribute", "Internal analyzer error.", "Internal", DiagnosticSeverity.Error, true), - null)); + context.ReportDiagnostic(Diagnostic.Create(SP0002, null)); continue; } @@ -791,6 +792,27 @@ private void MapResults( return builder.ToString(); } + if (canonicalOperationName == "FindById") + { + var builder = new StringBuilder(); + builder.Append("SELECT"); + var properties = entityType.GetMembers().OfType().ToList(); + for (var i = 0; i < properties.Count; i++) + { + builder.Append(" "); + builder.Append(properties[i].Name); + if (i != properties.Count - 1) + { + builder.Append(","); + } + } + + builder.Append(" FROM "); + builder.Append(entityType.Name); + AppendFitlerById(builder); + return builder.ToString(); + } + if (canonicalOperationName == "Count") { var builder = new StringBuilder(); @@ -799,7 +821,40 @@ private void MapResults( return builder.ToString(); } + if (canonicalOperationName == "DeleteAll") + { + var builder = new StringBuilder(); + builder.Append("DELETE FROM "); + builder.Append(entityType.Name); + return builder.ToString(); + } + + if (canonicalOperationName == "DeleteById") + { + var builder = new StringBuilder(); + builder.Append("DELETE FROM "); + builder.Append(entityType.Name); + AppendFitlerById(builder); + return builder.ToString(); + } + return null; + + void AppendFitlerById(StringBuilder builder) + { + builder.Append(" WHERE "); + var idMember = entityType.FindIdMember(); + if (idMember == null) + { + methodGenerationContext.ClassGenerationContext.GeneratorExecutionContext.ReportDiagnostic( + Diagnostic.Create(SP0003, methodGenerationContext.MethodSymbol.Locations.FirstOrDefault(), new object[] { entityType.ToDisplayString() })); + return; + } + + builder.Append(idMember.Name); + builder.Append(" = "); + builder.Append("@" + NameMapper.MapName(idMember.Name)); + } } private void ProcessMethod( @@ -835,7 +890,10 @@ private void ProcessMethod( } else if (!methodGenerationContext.IsDataReader) { - returnTypeName += "?"; + if (!returnTypeName.EndsWith("?")) + { + returnTypeName += "?"; + } } } diff --git a/SqlMarshal/NameMapper.cs b/SqlMarshal/NameMapper.cs index 5fdb52e..59799b4 100644 --- a/SqlMarshal/NameMapper.cs +++ b/SqlMarshal/NameMapper.cs @@ -23,6 +23,11 @@ public static string MapName(string parameterName) { var firstname = Regex.Match(parameterName, "[^A-Z]*").Value; var matches = Regex.Matches(parameterName, "[A-Z][^A-Z]*").Cast().Select(_ => _.Value.ToLower()); + if (string.IsNullOrEmpty(firstname)) + { + return string.Join("_", matches); + } + return string.Join("_", new string[] { firstname }.Union(matches)); } }