diff --git a/SqlMarshal.Tests/CrudGenerationTests.cs b/SqlMarshal.Tests/CrudGenerationTests.cs
index 7b3aa74..7176516 100644
--- a/SqlMarshal.Tests/CrudGenerationTests.cs
+++ b/SqlMarshal.Tests/CrudGenerationTests.cs
@@ -82,6 +82,85 @@ partial class C
Assert.AreEqual(expectedOutput, output);
}
+ [TestMethod]
+ public void FindById_IList_Sync()
+ {
+ string source = @"
+#nullable enable
+namespace Foo
+{
+ class TestEntity
+ {
+ public int Id { get; set; }
+
+ public string Name { get; set; }
+ }
+
+ [Repository(typeof(TestEntity))]
+ partial class C
+ {
+ private DbConnection connection;
+
+ public partial TestEntity? FindById(int id);
+ }
+}";
+ string output = this.GetGeneratedOutput(source, NullableContextOptions.Disable);
+
+ Assert.IsNotNull(output);
+
+ var expectedOutput = @"//
+// Code generated by Stored Procedures Code Generator.
+// Changes may cause incorrect behavior and will be lost if the code is
+// regenerated.
+//
+#nullable enable
+#pragma warning disable 1591
+
+namespace Foo
+{
+ using System;
+ using System.Collections.Generic;
+ using System.Data.Common;
+ using System.Linq;
+
+ partial class C
+ {
+ public partial Foo.TestEntity? FindById(int id)
+ {
+ var connection = this.connection;
+ using var command = connection.CreateCommand();
+
+ var idParameter = command.CreateParameter();
+ idParameter.ParameterName = ""@id"";
+ idParameter.Value = id;
+
+ var parameters = new DbParameter[]
+ {
+ idParameter,
+ };
+
+ var sqlQuery = @""SELECT Id, Name FROM TestEntity WHERE Id = @id"";
+ command.CommandText = sqlQuery;
+ command.Parameters.AddRange(parameters);
+ using var reader = command.ExecuteReader(System.Data.CommandBehavior.SingleResult | System.Data.CommandBehavior.SingleRow);
+ if (!reader.Read())
+ {
+ return null;
+ }
+
+ var result = new TestEntity();
+ var value_0 = reader.GetValue(0);
+ result.Id = (int)value_0;
+ var value_1 = reader.GetValue(1);
+ result.Name = value_1 == DBNull.Value ? (string?)null : (string)value_1;
+ reader.Close();
+ return result;
+ }
+ }
+}";
+ Assert.AreEqual(expectedOutput, output);
+ }
+
[TestMethod]
public void Count_int_Sync()
{
@@ -136,6 +215,130 @@ public partial int Count()
return (int)result!;
}
}
+}";
+ Assert.AreEqual(expectedOutput, output);
+ }
+
+ [TestMethod]
+ public void DeleteAll_IList_Sync()
+ {
+ string source = @"
+#nullable enable
+namespace Foo
+{
+ class TestEntity
+ {
+ public int Id { get; set; }
+
+ public string Name { get; set; }
+ }
+
+ [Repository(typeof(TestEntity))]
+ partial class C
+ {
+ private DbConnection connection;
+
+ public partial void DeleteAll();
+ }
+}";
+ string output = this.GetGeneratedOutput(source, NullableContextOptions.Disable);
+
+ Assert.IsNotNull(output);
+
+ var expectedOutput = @"//
+// Code generated by Stored Procedures Code Generator.
+// Changes may cause incorrect behavior and will be lost if the code is
+// regenerated.
+//
+#nullable enable
+#pragma warning disable 1591
+
+namespace Foo
+{
+ using System;
+ using System.Collections.Generic;
+ using System.Data.Common;
+ using System.Linq;
+
+ partial class C
+ {
+ public partial void DeleteAll()
+ {
+ var connection = this.connection;
+ using var command = connection.CreateCommand();
+
+ var sqlQuery = @""DELETE FROM TestEntity"";
+ command.CommandText = sqlQuery;
+ command.ExecuteNonQuery();
+ }
+ }
+}";
+ Assert.AreEqual(expectedOutput, output);
+ }
+
+ [TestMethod]
+ public void DeleteById_IList_Sync()
+ {
+ string source = @"
+#nullable enable
+namespace Foo
+{
+ class TestEntity
+ {
+ public int Id { get; set; }
+
+ public string Name { get; set; }
+ }
+
+ [Repository(typeof(TestEntity))]
+ partial class C
+ {
+ private DbConnection connection;
+
+ public partial void DeleteById(int id);
+ }
+}";
+ string output = this.GetGeneratedOutput(source, NullableContextOptions.Disable);
+
+ Assert.IsNotNull(output);
+
+ var expectedOutput = @"//
+// Code generated by Stored Procedures Code Generator.
+// Changes may cause incorrect behavior and will be lost if the code is
+// regenerated.
+//
+#nullable enable
+#pragma warning disable 1591
+
+namespace Foo
+{
+ using System;
+ using System.Collections.Generic;
+ using System.Data.Common;
+ using System.Linq;
+
+ partial class C
+ {
+ public partial void DeleteById(int id)
+ {
+ var connection = this.connection;
+ using var command = connection.CreateCommand();
+
+ var idParameter = command.CreateParameter();
+ idParameter.ParameterName = ""@id"";
+ idParameter.Value = id;
+
+ var parameters = new DbParameter[]
+ {
+ idParameter,
+ };
+
+ var sqlQuery = @""DELETE FROM TestEntity WHERE Id = @id"";
+ command.CommandText = sqlQuery;
+ command.Parameters.AddRange(parameters);
+ command.ExecuteNonQuery();
+ }
+ }
}";
Assert.AreEqual(expectedOutput, output);
}
diff --git a/SqlMarshal/ClassGenerationContext.cs b/SqlMarshal/ClassGenerationContext.cs
index 748d3a3..c4d3601 100644
--- a/SqlMarshal/ClassGenerationContext.cs
+++ b/SqlMarshal/ClassGenerationContext.cs
@@ -18,13 +18,14 @@ public ClassGenerationContext(
List methods,
ISymbol attributeSymbol,
ISymbol repositoryAttributeSymbol,
- NullableContextOptions nullableContextOptions)
+ GeneratorExecutionContext context)
{
this.ClassSymbol = classSymbol;
this.Methods = methods.Select(_ => new MethodGenerationContext(this, _)).ToList();
this.AttributeSymbol = attributeSymbol;
this.RepositoryAttributeSymbol = repositoryAttributeSymbol;
- this.NullableContextOptions = nullableContextOptions;
+ this.GeneratorExecutionContext = context;
+ this.NullableContextOptions = context.Compilation.Options.NullableContextOptions;
this.ConnectionField = GetConnectionField(classSymbol);
this.DbContextField = GetContextField(classSymbol);
@@ -38,6 +39,8 @@ public ClassGenerationContext(
public ISymbol RepositoryAttributeSymbol { get; }
+ public GeneratorExecutionContext GeneratorExecutionContext { get; }
+
public NullableContextOptions NullableContextOptions { get; }
public bool HasNullableAnnotations => this.NullableContextOptions != NullableContextOptions.Disable;
diff --git a/SqlMarshal/Extensions.cs b/SqlMarshal/Extensions.cs
index cebf13d..34b1d4e 100644
--- a/SqlMarshal/Extensions.cs
+++ b/SqlMarshal/Extensions.cs
@@ -7,6 +7,8 @@
namespace SqlMarshal;
using Microsoft.CodeAnalysis;
+using Microsoft.CodeAnalysis.CSharp.Syntax;
+using System.Linq;
internal static class Extensions
{
@@ -174,4 +176,9 @@ internal static ITypeSymbol UnwrapListItem(ITypeSymbol returnType)
return returnType;
}
+
+ internal static IPropertySymbol? FindIdMember(this ITypeSymbol returnType)
+ {
+ return returnType.GetMembers().OfType().FirstOrDefault(_ => _.Name == "Id");
+ }
}
diff --git a/SqlMarshal/Generator.cs b/SqlMarshal/Generator.cs
index f71f246..c041ea9 100644
--- a/SqlMarshal/Generator.cs
+++ b/SqlMarshal/Generator.cs
@@ -8,6 +8,7 @@ namespace SqlMarshal;
using System.Collections.Generic;
using System.Linq;
+using System.Reflection.Metadata;
using System.Text;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
@@ -59,6 +60,12 @@ public RepositoryAttribute(global::System.Type entityType)
}
";
+ private static DiagnosticDescriptor SP0001 { get; } = new DiagnosticDescriptor("SP0001", "No stored procedure attribute", "Internal analyzer error.", "Internal", DiagnosticSeverity.Error, true);
+
+ private static DiagnosticDescriptor SP0002 { get; } = new DiagnosticDescriptor("SP0002", "No repository attribute", "Internal analyzer error.", "Internal", DiagnosticSeverity.Error, true);
+
+ private static DiagnosticDescriptor SP0003 { get; } = new DiagnosticDescriptor("SP0002", "Id property cannot be guessed", "Cannot find id property for entity type {0}.", "SqlMarshal", DiagnosticSeverity.Error, true);
+
///
public void Initialize(GeneratorInitializationContext context)
{
@@ -78,18 +85,14 @@ public void Execute(GeneratorExecutionContext context)
INamedTypeSymbol? attributeSymbol = context.Compilation.GetTypeByMetadataName("SqlMarshalAttribute");
if (attributeSymbol == null)
{
- context.ReportDiagnostic(Diagnostic.Create(
- new DiagnosticDescriptor("SP0001", "No stored procedure attribute", "Internal analyzer error.", "Internal", DiagnosticSeverity.Error, true),
- null));
+ context.ReportDiagnostic(Diagnostic.Create(SP0001, null));
return;
}
INamedTypeSymbol? repositoryAttributeSymbol = context.Compilation.GetTypeByMetadataName("RepositoryAttribute");
if (repositoryAttributeSymbol == null)
{
- context.ReportDiagnostic(Diagnostic.Create(
- new DiagnosticDescriptor("SP0002", "No repository attribute", "Internal analyzer error.", "Internal", DiagnosticSeverity.Error, true),
- null));
+ context.ReportDiagnostic(Diagnostic.Create(SP0002, null));
return;
}
@@ -104,7 +107,7 @@ public void Execute(GeneratorExecutionContext context)
group.ToList(),
attributeSymbol,
repositoryAttributeSymbol,
- context.Compilation.Options.NullableContextOptions);
+ context);
var sourceCode = this.ProcessClass(
generationContext,
(INamedTypeSymbol)group.Key!,
@@ -112,9 +115,7 @@ public void Execute(GeneratorExecutionContext context)
hasNullableAnnotations);
if (sourceCode == null)
{
- context.ReportDiagnostic(Diagnostic.Create(
- new DiagnosticDescriptor("SP0002", "No source code generated attribute", "Internal analyzer error.", "Internal", DiagnosticSeverity.Error, true),
- null));
+ context.ReportDiagnostic(Diagnostic.Create(SP0002, null));
continue;
}
@@ -791,6 +792,27 @@ private void MapResults(
return builder.ToString();
}
+ if (canonicalOperationName == "FindById")
+ {
+ var builder = new StringBuilder();
+ builder.Append("SELECT");
+ var properties = entityType.GetMembers().OfType().ToList();
+ for (var i = 0; i < properties.Count; i++)
+ {
+ builder.Append(" ");
+ builder.Append(properties[i].Name);
+ if (i != properties.Count - 1)
+ {
+ builder.Append(",");
+ }
+ }
+
+ builder.Append(" FROM ");
+ builder.Append(entityType.Name);
+ AppendFitlerById(builder);
+ return builder.ToString();
+ }
+
if (canonicalOperationName == "Count")
{
var builder = new StringBuilder();
@@ -799,7 +821,40 @@ private void MapResults(
return builder.ToString();
}
+ if (canonicalOperationName == "DeleteAll")
+ {
+ var builder = new StringBuilder();
+ builder.Append("DELETE FROM ");
+ builder.Append(entityType.Name);
+ return builder.ToString();
+ }
+
+ if (canonicalOperationName == "DeleteById")
+ {
+ var builder = new StringBuilder();
+ builder.Append("DELETE FROM ");
+ builder.Append(entityType.Name);
+ AppendFitlerById(builder);
+ return builder.ToString();
+ }
+
return null;
+
+ void AppendFitlerById(StringBuilder builder)
+ {
+ builder.Append(" WHERE ");
+ var idMember = entityType.FindIdMember();
+ if (idMember == null)
+ {
+ methodGenerationContext.ClassGenerationContext.GeneratorExecutionContext.ReportDiagnostic(
+ Diagnostic.Create(SP0003, methodGenerationContext.MethodSymbol.Locations.FirstOrDefault(), new object[] { entityType.ToDisplayString() }));
+ return;
+ }
+
+ builder.Append(idMember.Name);
+ builder.Append(" = ");
+ builder.Append("@" + NameMapper.MapName(idMember.Name));
+ }
}
private void ProcessMethod(
@@ -835,7 +890,10 @@ private void ProcessMethod(
}
else if (!methodGenerationContext.IsDataReader)
{
- returnTypeName += "?";
+ if (!returnTypeName.EndsWith("?"))
+ {
+ returnTypeName += "?";
+ }
}
}
diff --git a/SqlMarshal/NameMapper.cs b/SqlMarshal/NameMapper.cs
index 5fdb52e..59799b4 100644
--- a/SqlMarshal/NameMapper.cs
+++ b/SqlMarshal/NameMapper.cs
@@ -23,6 +23,11 @@ public static string MapName(string parameterName)
{
var firstname = Regex.Match(parameterName, "[^A-Z]*").Value;
var matches = Regex.Matches(parameterName, "[A-Z][^A-Z]*").Cast().Select(_ => _.Value.ToLower());
+ if (string.IsNullOrEmpty(firstname))
+ {
+ return string.Join("_", matches);
+ }
+
return string.Join("_", new string[] { firstname }.Union(matches));
}
}