diff --git a/SqlMarshal.Tests/CrudGenerationTests.cs b/SqlMarshal.Tests/CrudGenerationTests.cs
index 7176516..a531f62 100644
--- a/SqlMarshal.Tests/CrudGenerationTests.cs
+++ b/SqlMarshal.Tests/CrudGenerationTests.cs
@@ -339,6 +339,164 @@ public partial void DeleteById(int id)
command.ExecuteNonQuery();
}
}
+}";
+ Assert.AreEqual(expectedOutput, output);
+ }
+
+ [TestMethod]
+ public void Update_void_Sync()
+ {
+ string source = @"
+#nullable enable
+namespace Foo
+{
+ class TestEntity
+ {
+ public int Id { get; set; }
+
+ public string Name { get; set; }
+
+ public string Description { get; set; }
+ }
+
+ [Repository(typeof(TestEntity))]
+ partial class C
+ {
+ private DbConnection connection;
+
+ public partial void Update(int id, string name, string description);
+ }
+}";
+ string output = this.GetGeneratedOutput(source, NullableContextOptions.Disable);
+
+ Assert.IsNotNull(output);
+
+ var expectedOutput = @"//
+// Code generated by Stored Procedures Code Generator.
+// Changes may cause incorrect behavior and will be lost if the code is
+// regenerated.
+//
+#nullable enable
+#pragma warning disable 1591
+
+namespace Foo
+{
+ using System;
+ using System.Collections.Generic;
+ using System.Data.Common;
+ using System.Linq;
+
+ partial class C
+ {
+ public partial void Update(int id, string name, string description)
+ {
+ var connection = this.connection;
+ using var command = connection.CreateCommand();
+
+ var idParameter = command.CreateParameter();
+ idParameter.ParameterName = ""@id"";
+ idParameter.Value = id;
+
+ var nameParameter = command.CreateParameter();
+ nameParameter.ParameterName = ""@name"";
+ nameParameter.Value = name == null ? (object)DBNull.Value : name;
+
+ var descriptionParameter = command.CreateParameter();
+ descriptionParameter.ParameterName = ""@description"";
+ descriptionParameter.Value = description == null ? (object)DBNull.Value : description;
+
+ var parameters = new DbParameter[]
+ {
+ idParameter,
+ nameParameter,
+ descriptionParameter,
+ };
+
+ var sqlQuery = @""UPDATE TestEntity SET Name = @name, Description = @description WHERE Id = @id"";
+ command.CommandText = sqlQuery;
+ command.Parameters.AddRange(parameters);
+ command.ExecuteNonQuery();
+ }
+ }
+}";
+ Assert.AreEqual(expectedOutput, output);
+ }
+
+ [TestMethod]
+ public void Insert_void_Sync()
+ {
+ string source = @"
+#nullable enable
+namespace Foo
+{
+ class TestEntity
+ {
+ public int Id { get; set; }
+
+ public string Name { get; set; }
+
+ public string Description { get; set; }
+ }
+
+ [Repository(typeof(TestEntity))]
+ partial class C
+ {
+ private DbConnection connection;
+
+ public partial void Insert(int id, string name, string description);
+ }
+}";
+ string output = this.GetGeneratedOutput(source, NullableContextOptions.Disable);
+
+ Assert.IsNotNull(output);
+
+ var expectedOutput = @"//
+// Code generated by Stored Procedures Code Generator.
+// Changes may cause incorrect behavior and will be lost if the code is
+// regenerated.
+//
+#nullable enable
+#pragma warning disable 1591
+
+namespace Foo
+{
+ using System;
+ using System.Collections.Generic;
+ using System.Data.Common;
+ using System.Linq;
+
+ partial class C
+ {
+ public partial void Insert(int id, string name, string description)
+ {
+ var connection = this.connection;
+ using var command = connection.CreateCommand();
+
+ var idParameter = command.CreateParameter();
+ idParameter.ParameterName = ""@id"";
+ idParameter.Value = id;
+
+ var nameParameter = command.CreateParameter();
+ nameParameter.ParameterName = ""@name"";
+ nameParameter.Value = name == null ? (object)DBNull.Value : name;
+
+ var descriptionParameter = command.CreateParameter();
+ descriptionParameter.ParameterName = ""@description"";
+ descriptionParameter.Value = description == null ? (object)DBNull.Value : description;
+
+ var parameters = new DbParameter[]
+ {
+ idParameter,
+ nameParameter,
+ descriptionParameter,
+ };
+
+ var sqlQuery = @""INSERT INTO TestEntity(Id, Name, Description) VALUES (@id, @name, @description)"";
+ command.CommandText = sqlQuery;
+ command.Parameters.AddRange(parameters);
+ command.ExecuteNonQuery();
+ }
+ }
}";
Assert.AreEqual(expectedOutput, output);
}
diff --git a/SqlMarshal/Extensions.cs b/SqlMarshal/Extensions.cs
index 34b1d4e..310fad7 100644
--- a/SqlMarshal/Extensions.cs
+++ b/SqlMarshal/Extensions.cs
@@ -179,6 +179,17 @@ internal static ITypeSymbol UnwrapListItem(ITypeSymbol returnType)
internal static IPropertySymbol? FindIdMember(this ITypeSymbol returnType)
{
- return returnType.GetMembers().OfType().FirstOrDefault(_ => _.Name == "Id");
+ return returnType.GetMembers().OfType().FirstOrDefault(FindIdMember);
+ }
+
+ internal static bool FindIdMember(this IPropertySymbol propertySymbol)
+ {
+ return propertySymbol.Name == "Id";
+ }
+
+ internal static IPropertySymbol? FindMember(this ITypeSymbol returnType, string parameterName)
+ {
+ return returnType.GetMembers().OfType()
+ .FirstOrDefault(propertySymbol => string.Equals(propertySymbol.Name, parameterName, System.StringComparison.InvariantCultureIgnoreCase));
}
}
diff --git a/SqlMarshal/Generator.cs b/SqlMarshal/Generator.cs
index c041ea9..03336e0 100644
--- a/SqlMarshal/Generator.cs
+++ b/SqlMarshal/Generator.cs
@@ -7,6 +7,7 @@
namespace SqlMarshal;
using System.Collections.Generic;
+using System.Data.Common;
using System.Linq;
using System.Reflection.Metadata;
using System.Text;
@@ -64,7 +65,11 @@ public RepositoryAttribute(global::System.Type entityType)
private static DiagnosticDescriptor SP0002 { get; } = new DiagnosticDescriptor("SP0002", "No repository attribute", "Internal analyzer error.", "Internal", DiagnosticSeverity.Error, true);
- private static DiagnosticDescriptor SP0003 { get; } = new DiagnosticDescriptor("SP0002", "Id property cannot be guessed", "Cannot find id property for entity type {0}.", "SqlMarshal", DiagnosticSeverity.Error, true);
+ private static DiagnosticDescriptor SP0003 { get; } = new DiagnosticDescriptor("SP0003", "Id property cannot be guessed", "Cannot find id property for entity type {0}.", "SqlMarshal", DiagnosticSeverity.Error, true);
+
+ private static DiagnosticDescriptor SP0004 { get; } = new DiagnosticDescriptor("SP0004", "Entity property corresponding to parameter cannot be guessed", "Cannot find property in entity type {0} corresponding to parameter {1}.", "SqlMarshal", DiagnosticSeverity.Error, true);
+
+ private static DiagnosticDescriptor SP0005 { get; } = new DiagnosticDescriptor("SP0005", "Unknown method for generation", "Unknown method {0} for generation.", "SqlMarshal", DiagnosticSeverity.Error, true);
///
public void Initialize(GeneratorInitializationContext context)
@@ -809,7 +814,7 @@ private void MapResults(
builder.Append(" FROM ");
builder.Append(entityType.Name);
- AppendFitlerById(builder);
+ AppendFilterById(builder);
return builder.ToString();
}
@@ -834,13 +839,99 @@ private void MapResults(
var builder = new StringBuilder();
builder.Append("DELETE FROM ");
builder.Append(entityType.Name);
- AppendFitlerById(builder);
+ AppendFilterById(builder);
+ return builder.ToString();
+ }
+
+ if (canonicalOperationName == "Update")
+ {
+ var builder = new StringBuilder();
+ builder.Append("UPDATE ");
+ builder.Append(entityType.Name);
+ builder.Append(" SET ");
+ bool first = true;
+ foreach (var parameter in methodGenerationContext.SqlParameters)
+ {
+ if (parameter.Name == "id")
+ {
+ continue;
+ }
+
+ if (!first)
+ {
+ builder.Append(", ");
+ }
+
+ var entityProperty = entityType.FindMember(parameter.Name);
+ if (entityProperty == null)
+ {
+ methodGenerationContext.ClassGenerationContext.GeneratorExecutionContext.ReportDiagnostic(Diagnostic.Create(SP0004, parameter.Locations.FirstOrDefault(), entityType.ToDisplayString(), parameter.Name));
+ continue;
+ }
+
+ builder.Append(entityProperty.Name);
+ builder.Append(" = ");
+ builder.Append("@" + NameMapper.MapName(parameter.Name));
+ first = false;
+ }
+
+ AppendFilterById(builder);
+ return builder.ToString();
+ }
+
+ if (canonicalOperationName == "Insert")
+ {
+ var builder = new StringBuilder();
+ builder.Append("INSERT INTO ");
+ builder.Append(entityType.Name);
+ builder.Append("(");
+ bool first = true;
+ foreach (var parameter in methodGenerationContext.SqlParameters)
+ {
+ if (!first)
+ {
+ builder.Append(", ");
+ }
+
+ var entityProperty = entityType.FindMember(parameter.Name);
+ if (entityProperty == null)
+ {
+ methodGenerationContext.ClassGenerationContext.GeneratorExecutionContext.ReportDiagnostic(Diagnostic.Create(SP0004, parameter.Locations.FirstOrDefault(), entityType.ToDisplayString(), parameter.Name));
+ continue;
+ }
+
+ builder.Append(entityProperty.Name);
+ first = false;
+ }
+
+ builder.Append(") VALUES (");
+ first = true;
+ foreach (var parameter in methodGenerationContext.SqlParameters)
+ {
+ if (!first)
+ {
+ builder.Append(", ");
+ }
+
+ var entityProperty = entityType.FindMember(parameter.Name);
+ if (entityProperty == null)
+ {
+ methodGenerationContext.ClassGenerationContext.GeneratorExecutionContext.ReportDiagnostic(Diagnostic.Create(SP0004, parameter.Locations.FirstOrDefault(), entityType.ToDisplayString(), parameter.Name));
+ continue;
+ }
+
+ builder.Append("@" + NameMapper.MapName(parameter.Name));
+ first = false;
+ }
+
+ builder.Append(")");
return builder.ToString();
}
+ methodGenerationContext.ClassGenerationContext.GeneratorExecutionContext.ReportDiagnostic(Diagnostic.Create(SP0005, methodGenerationContext.MethodSymbol.Locations.FirstOrDefault(), methodGenerationContext.MethodSymbol.ToDisplayString()));
return null;
- void AppendFitlerById(StringBuilder builder)
+ void AppendFilterById(StringBuilder builder)
{
builder.Append(" WHERE ");
var idMember = entityType.FindIdMember();