Skip to content

Commit

Permalink
Add insert and update generation
Browse files Browse the repository at this point in the history
  • Loading branch information
kant2002 committed Nov 30, 2024
1 parent 622c72f commit 7a97e80
Show file tree
Hide file tree
Showing 3 changed files with 265 additions and 5 deletions.
158 changes: 158 additions & 0 deletions SqlMarshal.Tests/CrudGenerationTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,164 @@ public partial void DeleteById(int id)
command.ExecuteNonQuery();
}
}
}";
Assert.AreEqual(expectedOutput, output);
}

[TestMethod]
public void Update_void_Sync()
{
string source = @"
#nullable enable
namespace Foo
{
class TestEntity
{
public int Id { get; set; }
public string Name { get; set; }
public string Description { get; set; }
}
[Repository(typeof(TestEntity))]
partial class C
{
private DbConnection connection;
public partial void Update(int id, string name, string description);
}
}";
string output = this.GetGeneratedOutput(source, NullableContextOptions.Disable);

Assert.IsNotNull(output);

var expectedOutput = @"// <auto-generated>
// Code generated by Stored Procedures Code Generator.
// Changes may cause incorrect behavior and will be lost if the code is
// regenerated.
// </auto-generated>
#nullable enable
#pragma warning disable 1591
namespace Foo
{
using System;
using System.Collections.Generic;
using System.Data.Common;
using System.Linq;
partial class C
{
public partial void Update(int id, string name, string description)
{
var connection = this.connection;
using var command = connection.CreateCommand();
var idParameter = command.CreateParameter();
idParameter.ParameterName = ""@id"";
idParameter.Value = id;
var nameParameter = command.CreateParameter();
nameParameter.ParameterName = ""@name"";
nameParameter.Value = name == null ? (object)DBNull.Value : name;
var descriptionParameter = command.CreateParameter();
descriptionParameter.ParameterName = ""@description"";
descriptionParameter.Value = description == null ? (object)DBNull.Value : description;
var parameters = new DbParameter[]
{
idParameter,
nameParameter,
descriptionParameter,
};
var sqlQuery = @""UPDATE TestEntity SET Name = @name, Description = @description WHERE Id = @id"";
command.CommandText = sqlQuery;
command.Parameters.AddRange(parameters);
command.ExecuteNonQuery();
}
}
}";
Assert.AreEqual(expectedOutput, output);
}

[TestMethod]
public void Insert_void_Sync()
{
string source = @"
#nullable enable
namespace Foo
{
class TestEntity
{
public int Id { get; set; }
public string Name { get; set; }
public string Description { get; set; }
}
[Repository(typeof(TestEntity))]
partial class C
{
private DbConnection connection;
public partial void Insert(int id, string name, string description);
}
}";
string output = this.GetGeneratedOutput(source, NullableContextOptions.Disable);

Assert.IsNotNull(output);

var expectedOutput = @"// <auto-generated>
// Code generated by Stored Procedures Code Generator.
// Changes may cause incorrect behavior and will be lost if the code is
// regenerated.
// </auto-generated>
#nullable enable
#pragma warning disable 1591
namespace Foo
{
using System;
using System.Collections.Generic;
using System.Data.Common;
using System.Linq;
partial class C
{
public partial void Insert(int id, string name, string description)
{
var connection = this.connection;
using var command = connection.CreateCommand();
var idParameter = command.CreateParameter();
idParameter.ParameterName = ""@id"";
idParameter.Value = id;
var nameParameter = command.CreateParameter();
nameParameter.ParameterName = ""@name"";
nameParameter.Value = name == null ? (object)DBNull.Value : name;
var descriptionParameter = command.CreateParameter();
descriptionParameter.ParameterName = ""@description"";
descriptionParameter.Value = description == null ? (object)DBNull.Value : description;
var parameters = new DbParameter[]
{
idParameter,
nameParameter,
descriptionParameter,
};
var sqlQuery = @""INSERT INTO TestEntity(Id, Name, Description) VALUES (@id, @name, @description)"";
command.CommandText = sqlQuery;
command.Parameters.AddRange(parameters);
command.ExecuteNonQuery();
}
}
}";
Assert.AreEqual(expectedOutput, output);
}
Expand Down
13 changes: 12 additions & 1 deletion SqlMarshal/Extensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,17 @@ internal static ITypeSymbol UnwrapListItem(ITypeSymbol returnType)

internal static IPropertySymbol? FindIdMember(this ITypeSymbol returnType)
{
return returnType.GetMembers().OfType<IPropertySymbol>().FirstOrDefault(_ => _.Name == "Id");
return returnType.GetMembers().OfType<IPropertySymbol>().FirstOrDefault(FindIdMember);
}

internal static bool FindIdMember(this IPropertySymbol propertySymbol)
{
return propertySymbol.Name == "Id";
}

internal static IPropertySymbol? FindMember(this ITypeSymbol returnType, string parameterName)
{
return returnType.GetMembers().OfType<IPropertySymbol>()
.FirstOrDefault(propertySymbol => string.Equals(propertySymbol.Name, parameterName, System.StringComparison.InvariantCultureIgnoreCase));
}
}
99 changes: 95 additions & 4 deletions SqlMarshal/Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
namespace SqlMarshal;

using System.Collections.Generic;
using System.Data.Common;
using System.Linq;
using System.Reflection.Metadata;
using System.Text;
Expand Down Expand Up @@ -64,7 +65,11 @@ public RepositoryAttribute(global::System.Type entityType)

private static DiagnosticDescriptor SP0002 { get; } = new DiagnosticDescriptor("SP0002", "No repository attribute", "Internal analyzer error.", "Internal", DiagnosticSeverity.Error, true);

private static DiagnosticDescriptor SP0003 { get; } = new DiagnosticDescriptor("SP0002", "Id property cannot be guessed", "Cannot find id property for entity type {0}.", "SqlMarshal", DiagnosticSeverity.Error, true);
private static DiagnosticDescriptor SP0003 { get; } = new DiagnosticDescriptor("SP0003", "Id property cannot be guessed", "Cannot find id property for entity type {0}.", "SqlMarshal", DiagnosticSeverity.Error, true);

private static DiagnosticDescriptor SP0004 { get; } = new DiagnosticDescriptor("SP0004", "Entity property corresponding to parameter cannot be guessed", "Cannot find property in entity type {0} corresponding to parameter {1}.", "SqlMarshal", DiagnosticSeverity.Error, true);

private static DiagnosticDescriptor SP0005 { get; } = new DiagnosticDescriptor("SP0005", "Unknown method for generation", "Unknown method {0} for generation.", "SqlMarshal", DiagnosticSeverity.Error, true);

/// <inheritdoc/>
public void Initialize(GeneratorInitializationContext context)
Expand Down Expand Up @@ -809,7 +814,7 @@ private void MapResults(

builder.Append(" FROM ");
builder.Append(entityType.Name);
AppendFitlerById(builder);
AppendFilterById(builder);
return builder.ToString();
}

Expand All @@ -834,13 +839,99 @@ private void MapResults(
var builder = new StringBuilder();
builder.Append("DELETE FROM ");
builder.Append(entityType.Name);
AppendFitlerById(builder);
AppendFilterById(builder);
return builder.ToString();
}

if (canonicalOperationName == "Update")
{
var builder = new StringBuilder();
builder.Append("UPDATE ");
builder.Append(entityType.Name);
builder.Append(" SET ");
bool first = true;
foreach (var parameter in methodGenerationContext.SqlParameters)
{
if (parameter.Name == "id")
{
continue;
}

if (!first)
{
builder.Append(", ");
}

var entityProperty = entityType.FindMember(parameter.Name);
if (entityProperty == null)
{
methodGenerationContext.ClassGenerationContext.GeneratorExecutionContext.ReportDiagnostic(Diagnostic.Create(SP0004, parameter.Locations.FirstOrDefault(), entityType.ToDisplayString(), parameter.Name));
continue;
}

builder.Append(entityProperty.Name);
builder.Append(" = ");
builder.Append("@" + NameMapper.MapName(parameter.Name));
first = false;
}

AppendFilterById(builder);
return builder.ToString();
}

if (canonicalOperationName == "Insert")
{
var builder = new StringBuilder();
builder.Append("INSERT INTO ");
builder.Append(entityType.Name);
builder.Append("(");
bool first = true;
foreach (var parameter in methodGenerationContext.SqlParameters)
{
if (!first)
{
builder.Append(", ");
}

var entityProperty = entityType.FindMember(parameter.Name);
if (entityProperty == null)
{
methodGenerationContext.ClassGenerationContext.GeneratorExecutionContext.ReportDiagnostic(Diagnostic.Create(SP0004, parameter.Locations.FirstOrDefault(), entityType.ToDisplayString(), parameter.Name));
continue;
}

builder.Append(entityProperty.Name);
first = false;
}

builder.Append(") VALUES (");
first = true;
foreach (var parameter in methodGenerationContext.SqlParameters)
{
if (!first)
{
builder.Append(", ");
}

var entityProperty = entityType.FindMember(parameter.Name);
if (entityProperty == null)
{
methodGenerationContext.ClassGenerationContext.GeneratorExecutionContext.ReportDiagnostic(Diagnostic.Create(SP0004, parameter.Locations.FirstOrDefault(), entityType.ToDisplayString(), parameter.Name));
continue;
}

builder.Append("@" + NameMapper.MapName(parameter.Name));
first = false;
}

builder.Append(")");
return builder.ToString();
}

methodGenerationContext.ClassGenerationContext.GeneratorExecutionContext.ReportDiagnostic(Diagnostic.Create(SP0005, methodGenerationContext.MethodSymbol.Locations.FirstOrDefault(), methodGenerationContext.MethodSymbol.ToDisplayString()));
return null;

void AppendFitlerById(StringBuilder builder)
void AppendFilterById(StringBuilder builder)
{
builder.Append(" WHERE ");
var idMember = entityType.FindIdMember();
Expand Down

0 comments on commit 7a97e80

Please sign in to comment.