diff --git a/SqlMarshal.Tests/CrudGenerationTests.cs b/SqlMarshal.Tests/CrudGenerationTests.cs index 7176516..a531f62 100644 --- a/SqlMarshal.Tests/CrudGenerationTests.cs +++ b/SqlMarshal.Tests/CrudGenerationTests.cs @@ -339,6 +339,164 @@ public partial void DeleteById(int id) command.ExecuteNonQuery(); } } +}"; + Assert.AreEqual(expectedOutput, output); + } + + [TestMethod] + public void Update_void_Sync() + { + string source = @" +#nullable enable +namespace Foo +{ + class TestEntity + { + public int Id { get; set; } + + public string Name { get; set; } + + public string Description { get; set; } + } + + [Repository(typeof(TestEntity))] + partial class C + { + private DbConnection connection; + + public partial void Update(int id, string name, string description); + } +}"; + string output = this.GetGeneratedOutput(source, NullableContextOptions.Disable); + + Assert.IsNotNull(output); + + var expectedOutput = @"// +// Code generated by Stored Procedures Code Generator. +// Changes may cause incorrect behavior and will be lost if the code is +// regenerated. +// +#nullable enable +#pragma warning disable 1591 + +namespace Foo +{ + using System; + using System.Collections.Generic; + using System.Data.Common; + using System.Linq; + + partial class C + { + public partial void Update(int id, string name, string description) + { + var connection = this.connection; + using var command = connection.CreateCommand(); + + var idParameter = command.CreateParameter(); + idParameter.ParameterName = ""@id""; + idParameter.Value = id; + + var nameParameter = command.CreateParameter(); + nameParameter.ParameterName = ""@name""; + nameParameter.Value = name == null ? (object)DBNull.Value : name; + + var descriptionParameter = command.CreateParameter(); + descriptionParameter.ParameterName = ""@description""; + descriptionParameter.Value = description == null ? (object)DBNull.Value : description; + + var parameters = new DbParameter[] + { + idParameter, + nameParameter, + descriptionParameter, + }; + + var sqlQuery = @""UPDATE TestEntity SET Name = @name, Description = @description WHERE Id = @id""; + command.CommandText = sqlQuery; + command.Parameters.AddRange(parameters); + command.ExecuteNonQuery(); + } + } +}"; + Assert.AreEqual(expectedOutput, output); + } + + [TestMethod] + public void Insert_void_Sync() + { + string source = @" +#nullable enable +namespace Foo +{ + class TestEntity + { + public int Id { get; set; } + + public string Name { get; set; } + + public string Description { get; set; } + } + + [Repository(typeof(TestEntity))] + partial class C + { + private DbConnection connection; + + public partial void Insert(int id, string name, string description); + } +}"; + string output = this.GetGeneratedOutput(source, NullableContextOptions.Disable); + + Assert.IsNotNull(output); + + var expectedOutput = @"// +// Code generated by Stored Procedures Code Generator. +// Changes may cause incorrect behavior and will be lost if the code is +// regenerated. +// +#nullable enable +#pragma warning disable 1591 + +namespace Foo +{ + using System; + using System.Collections.Generic; + using System.Data.Common; + using System.Linq; + + partial class C + { + public partial void Insert(int id, string name, string description) + { + var connection = this.connection; + using var command = connection.CreateCommand(); + + var idParameter = command.CreateParameter(); + idParameter.ParameterName = ""@id""; + idParameter.Value = id; + + var nameParameter = command.CreateParameter(); + nameParameter.ParameterName = ""@name""; + nameParameter.Value = name == null ? (object)DBNull.Value : name; + + var descriptionParameter = command.CreateParameter(); + descriptionParameter.ParameterName = ""@description""; + descriptionParameter.Value = description == null ? (object)DBNull.Value : description; + + var parameters = new DbParameter[] + { + idParameter, + nameParameter, + descriptionParameter, + }; + + var sqlQuery = @""INSERT INTO TestEntity(Id, Name, Description) VALUES (@id, @name, @description)""; + command.CommandText = sqlQuery; + command.Parameters.AddRange(parameters); + command.ExecuteNonQuery(); + } + } }"; Assert.AreEqual(expectedOutput, output); } diff --git a/SqlMarshal/Extensions.cs b/SqlMarshal/Extensions.cs index 34b1d4e..310fad7 100644 --- a/SqlMarshal/Extensions.cs +++ b/SqlMarshal/Extensions.cs @@ -179,6 +179,17 @@ internal static ITypeSymbol UnwrapListItem(ITypeSymbol returnType) internal static IPropertySymbol? FindIdMember(this ITypeSymbol returnType) { - return returnType.GetMembers().OfType().FirstOrDefault(_ => _.Name == "Id"); + return returnType.GetMembers().OfType().FirstOrDefault(FindIdMember); + } + + internal static bool FindIdMember(this IPropertySymbol propertySymbol) + { + return propertySymbol.Name == "Id"; + } + + internal static IPropertySymbol? FindMember(this ITypeSymbol returnType, string parameterName) + { + return returnType.GetMembers().OfType() + .FirstOrDefault(propertySymbol => string.Equals(propertySymbol.Name, parameterName, System.StringComparison.InvariantCultureIgnoreCase)); } } diff --git a/SqlMarshal/Generator.cs b/SqlMarshal/Generator.cs index c041ea9..03336e0 100644 --- a/SqlMarshal/Generator.cs +++ b/SqlMarshal/Generator.cs @@ -7,6 +7,7 @@ namespace SqlMarshal; using System.Collections.Generic; +using System.Data.Common; using System.Linq; using System.Reflection.Metadata; using System.Text; @@ -64,7 +65,11 @@ public RepositoryAttribute(global::System.Type entityType) private static DiagnosticDescriptor SP0002 { get; } = new DiagnosticDescriptor("SP0002", "No repository attribute", "Internal analyzer error.", "Internal", DiagnosticSeverity.Error, true); - private static DiagnosticDescriptor SP0003 { get; } = new DiagnosticDescriptor("SP0002", "Id property cannot be guessed", "Cannot find id property for entity type {0}.", "SqlMarshal", DiagnosticSeverity.Error, true); + private static DiagnosticDescriptor SP0003 { get; } = new DiagnosticDescriptor("SP0003", "Id property cannot be guessed", "Cannot find id property for entity type {0}.", "SqlMarshal", DiagnosticSeverity.Error, true); + + private static DiagnosticDescriptor SP0004 { get; } = new DiagnosticDescriptor("SP0004", "Entity property corresponding to parameter cannot be guessed", "Cannot find property in entity type {0} corresponding to parameter {1}.", "SqlMarshal", DiagnosticSeverity.Error, true); + + private static DiagnosticDescriptor SP0005 { get; } = new DiagnosticDescriptor("SP0005", "Unknown method for generation", "Unknown method {0} for generation.", "SqlMarshal", DiagnosticSeverity.Error, true); /// public void Initialize(GeneratorInitializationContext context) @@ -809,7 +814,7 @@ private void MapResults( builder.Append(" FROM "); builder.Append(entityType.Name); - AppendFitlerById(builder); + AppendFilterById(builder); return builder.ToString(); } @@ -834,13 +839,99 @@ private void MapResults( var builder = new StringBuilder(); builder.Append("DELETE FROM "); builder.Append(entityType.Name); - AppendFitlerById(builder); + AppendFilterById(builder); + return builder.ToString(); + } + + if (canonicalOperationName == "Update") + { + var builder = new StringBuilder(); + builder.Append("UPDATE "); + builder.Append(entityType.Name); + builder.Append(" SET "); + bool first = true; + foreach (var parameter in methodGenerationContext.SqlParameters) + { + if (parameter.Name == "id") + { + continue; + } + + if (!first) + { + builder.Append(", "); + } + + var entityProperty = entityType.FindMember(parameter.Name); + if (entityProperty == null) + { + methodGenerationContext.ClassGenerationContext.GeneratorExecutionContext.ReportDiagnostic(Diagnostic.Create(SP0004, parameter.Locations.FirstOrDefault(), entityType.ToDisplayString(), parameter.Name)); + continue; + } + + builder.Append(entityProperty.Name); + builder.Append(" = "); + builder.Append("@" + NameMapper.MapName(parameter.Name)); + first = false; + } + + AppendFilterById(builder); + return builder.ToString(); + } + + if (canonicalOperationName == "Insert") + { + var builder = new StringBuilder(); + builder.Append("INSERT INTO "); + builder.Append(entityType.Name); + builder.Append("("); + bool first = true; + foreach (var parameter in methodGenerationContext.SqlParameters) + { + if (!first) + { + builder.Append(", "); + } + + var entityProperty = entityType.FindMember(parameter.Name); + if (entityProperty == null) + { + methodGenerationContext.ClassGenerationContext.GeneratorExecutionContext.ReportDiagnostic(Diagnostic.Create(SP0004, parameter.Locations.FirstOrDefault(), entityType.ToDisplayString(), parameter.Name)); + continue; + } + + builder.Append(entityProperty.Name); + first = false; + } + + builder.Append(") VALUES ("); + first = true; + foreach (var parameter in methodGenerationContext.SqlParameters) + { + if (!first) + { + builder.Append(", "); + } + + var entityProperty = entityType.FindMember(parameter.Name); + if (entityProperty == null) + { + methodGenerationContext.ClassGenerationContext.GeneratorExecutionContext.ReportDiagnostic(Diagnostic.Create(SP0004, parameter.Locations.FirstOrDefault(), entityType.ToDisplayString(), parameter.Name)); + continue; + } + + builder.Append("@" + NameMapper.MapName(parameter.Name)); + first = false; + } + + builder.Append(")"); return builder.ToString(); } + methodGenerationContext.ClassGenerationContext.GeneratorExecutionContext.ReportDiagnostic(Diagnostic.Create(SP0005, methodGenerationContext.MethodSymbol.Locations.FirstOrDefault(), methodGenerationContext.MethodSymbol.ToDisplayString())); return null; - void AppendFitlerById(StringBuilder builder) + void AppendFilterById(StringBuilder builder) { builder.Append(" WHERE "); var idMember = entityType.FindIdMember();