Skip to content

Commit

Permalink
Add FindAll Crud method
Browse files Browse the repository at this point in the history
  • Loading branch information
kant2002 committed Nov 27, 2024
1 parent 8b2f069 commit b8be291
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 13 deletions.
84 changes: 84 additions & 0 deletions SqlMarshal.Tests/CrudGenerationTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// -----------------------------------------------------------------------
// <copyright file="CrudGenerationTests.cs" company="Andrii Kurdiumov">
// Copyright (c) Andrii Kurdiumov. All rights reserved.
// </copyright>
// -----------------------------------------------------------------------

namespace SqlMarshal.Tests;

using Microsoft.CodeAnalysis;
using Microsoft.VisualStudio.TestTools.UnitTesting;

[TestClass]
public class CrudGenerationTests : CodeGenerationTestBase
{
[TestMethod]
public void FindAll_IList_Sync()
{
string source = @"
#nullable enable
namespace Foo
{
class TestEntity
{
public int Id { get; set; }
public string Name { get; set; }
}
[Repository(typeof(TestEntity))]
partial class C
{
private DbConnection connection;
public partial IList<TestEntity> FindAll();
}
}";
string output = this.GetGeneratedOutput(source, NullableContextOptions.Disable);

Assert.IsNotNull(output);

var expectedOutput = @"// <auto-generated>
// Code generated by Stored Procedures Code Generator.
// Changes may cause incorrect behavior and will be lost if the code is
// regenerated.
// </auto-generated>
#nullable enable
#pragma warning disable 1591
namespace Foo
{
using System;
using System.Collections.Generic;
using System.Data.Common;
using System.Linq;
partial class C
{
public partial IList<Foo.TestEntity> FindAll()
{
var connection = this.connection;
using var command = connection.CreateCommand();
var sqlQuery = @""SELECT Id, Name FROM TestEntity"";
command.CommandText = sqlQuery;
using var reader = command.ExecuteReader();
var result = new List<TestEntity>();
while (reader.Read())
{
var item = new TestEntity();
var value_0 = reader.GetValue(0);
item.Id = (int)value_0;
var value_1 = reader.GetValue(1);
item.Name = value_1 == DBNull.Value ? (string?)null : (string)value_1;
result.Add(item);
}
reader.Close();
return result;
}
}
}";
Assert.AreEqual(expectedOutput, output);
}
}
8 changes: 8 additions & 0 deletions SqlMarshal/ClassGenerationContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@ public ClassGenerationContext(
INamedTypeSymbol classSymbol,
List<IMethodSymbol> methods,
ISymbol attributeSymbol,
ISymbol repositoryAttributeSymbol,
NullableContextOptions nullableContextOptions)
{
this.ClassSymbol = classSymbol;
this.Methods = methods.Select(_ => new MethodGenerationContext(this, _)).ToList();
this.AttributeSymbol = attributeSymbol;
this.RepositoryAttributeSymbol = repositoryAttributeSymbol;
this.NullableContextOptions = nullableContextOptions;

this.ConnectionField = GetConnectionField(classSymbol);
Expand All @@ -34,6 +36,8 @@ public ClassGenerationContext(

public ISymbol AttributeSymbol { get; }

public ISymbol RepositoryAttributeSymbol { get; }

public NullableContextOptions NullableContextOptions { get; }

public bool HasNullableAnnotations => this.NullableContextOptions != NullableContextOptions.Disable;
Expand All @@ -44,6 +48,10 @@ public ClassGenerationContext(

public string DbContextName => this.DbContextField?.Name ?? "dbContext";

public bool IsRepository => this.ClassSymbol.GetAttributes().Any(ad => ad.AttributeClass!.Equals(this.RepositoryAttributeSymbol, SymbolEqualityComparer.Default));

public ITypeSymbol? RepositoryEntityType => (ITypeSymbol?)this.ClassSymbol.GetAttributes().Single(ad => ad.AttributeClass!.Equals(this.RepositoryAttributeSymbol, SymbolEqualityComparer.Default)).ConstructorArguments.ElementAtOrDefault(0).Value;

public bool HasEfCore => this.ConnectionField == null && this.Methods.All(_ => _.ConnectionParameter == null);

public bool HasCollections => !this.HasEfCore || this.Methods.Any(_ => (_.IsList || _.IsEnumerable) && (IsScalarType(_.ItemType) || IsTuple(_.ItemType)));
Expand Down
98 changes: 85 additions & 13 deletions SqlMarshal/Generator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,17 @@ internal sealed class RawSqlAttribute: System.Attribute
{
public RawSqlAttribute() {}
}
[System.AttributeUsage(System.AttributeTargets.Class, AllowMultiple=false)]
internal sealed class RepositoryAttribute: System.Attribute
{
public RepositoryAttribute(global::System.Type entityType)
{
EntityType = entityType;
}
public global::System.Type EntityType { get; }
}
";

/// <inheritdoc/>
Expand Down Expand Up @@ -73,6 +84,15 @@ public void Execute(GeneratorExecutionContext context)
return;
}

INamedTypeSymbol? repositoryAttributeSymbol = context.Compilation.GetTypeByMetadataName("RepositoryAttribute");
if (repositoryAttributeSymbol == null)
{
context.ReportDiagnostic(Diagnostic.Create(
new DiagnosticDescriptor("SP0002", "No repository attribute", "Internal analyzer error.", "Internal", DiagnosticSeverity.Error, true),
null));
return;
}

var hasNullableAnnotations = context.Compilation.Options.NullableContextOptions != NullableContextOptions.Disable;

// Group the fields by class, and generate the source
Expand All @@ -83,6 +103,7 @@ public void Execute(GeneratorExecutionContext context)
(INamedTypeSymbol)group.Key!,
group.ToList(),
attributeSymbol,
repositoryAttributeSymbol,
context.Compilation.Options.NullableContextOptions);
var sourceCode = this.ProcessClass(
generationContext,
Expand Down Expand Up @@ -368,6 +389,20 @@ private static void ExecuteSimpleQuery(
}
}

private static string? GetProcedureName(IMethodSymbol methodSymbol, ISymbol attributeSymbol)
{
AttributeData? attributeData = methodSymbol.GetAttributes().FirstOrDefault(ad => ad.AttributeClass!.Equals(attributeSymbol, SymbolEqualityComparer.Default));
if (attributeData == null)
{
return null;
}

TypedConstant overridenNameOpt = attributeData.NamedArguments.SingleOrDefault(kvp => kvp.Key == "PropertyName").Value;
var procedureNameConstraint = attributeData.ConstructorArguments.ElementAtOrDefault(0);
object? procedureName = procedureNameConstraint.Value;
return (string?)procedureName;
}

private string? ProcessClass(
ClassGenerationContext classGenerationContext,
INamedTypeSymbol classSymbol,
Expand Down Expand Up @@ -423,7 +458,6 @@ namespace {namespaceName}
source,
methodGenerationContext,
methodGenerationContext.MethodSymbol,
attributeSymbol,
hasNullableAnnotations);
}

Expand Down Expand Up @@ -728,11 +762,42 @@ private void MapResults(
}
}

private string? GetQueryForRepositoryMethod(MethodGenerationContext methodGenerationContext)
{
var canonicalOperationName = methodGenerationContext.MethodSymbol.Name;
var entityType = methodGenerationContext.ClassGenerationContext.RepositoryEntityType;
if (entityType is null)
{
return null;
}

if (canonicalOperationName == "FindAll")
{
var builder = new StringBuilder();
builder.Append("SELECT");
var properties = entityType.GetMembers().OfType<IPropertySymbol>().ToList();
for (var i = 0; i < properties.Count; i++)
{
builder.Append(" ");
builder.Append(properties[i].Name);
if (i != properties.Count - 1)
{
builder.Append(",");
}
}

builder.Append(" FROM ");
builder.Append(entityType.Name);
return builder.ToString();
}

return null;
}

private void ProcessMethod(
IndentedStringBuilder source,
MethodGenerationContext methodGenerationContext,
IMethodSymbol methodSymbol,
ISymbol attributeSymbol,
bool hasNullableAnnotations)
{
// get the name and type of the field
Expand All @@ -741,11 +806,7 @@ private void ProcessMethod(
var symbol = (ISymbol)methodSymbol;
var isTask = methodGenerationContext.IsTask;

// get the AutoNotify attribute from the field, and any associated data
AttributeData attributeData = methodSymbol.GetAttributes().Single(ad => ad.AttributeClass!.Equals(attributeSymbol, SymbolEqualityComparer.Default));
TypedConstant overridenNameOpt = attributeData.NamedArguments.SingleOrDefault(kvp => kvp.Key == "PropertyName").Value;
var procedureNameConstraint = attributeData.ConstructorArguments.ElementAtOrDefault(0);
object? procedureName = procedureNameConstraint.Value;
string? procedureName = GetProcedureName(methodSymbol, methodGenerationContext.ClassGenerationContext.AttributeSymbol);
var parameters = methodGenerationContext.SqlParameters;
var originalParameters = methodSymbol.Parameters;

Expand Down Expand Up @@ -803,14 +864,21 @@ private void ProcessMethod(

if (!hasCustomSql)
{
if (parameters.Length == 0)
if (procedureName == null)
{
source.AppendLine($@"var sqlQuery = @""{procedureName}"";");
source.AppendLine($@"var sqlQuery = @""{this.GetQueryForRepositoryMethod(methodGenerationContext)}"";");
}
else
{
string parametersList = string.Join(", ", parameters.Select(parameter => GetParameterPassing(parameter)));
source.AppendLine($@"var sqlQuery = @""{procedureName} {parametersList}"";");
if (parameters.Length == 0)
{
source.AppendLine($@"var sqlQuery = @""{procedureName}"";");
}
else
{
string parametersList = string.Join(", ", parameters.Select(parameter => GetParameterPassing(parameter)));
source.AppendLine($@"var sqlQuery = @""{procedureName} {parametersList}"";");
}
}
}

Expand Down Expand Up @@ -925,8 +993,7 @@ internal class SyntaxReceiver : ISyntaxContextReceiver
public void OnVisitSyntaxNode(GeneratorSyntaxContext context)
{
// any field with at least one attribute is a candidate for property generation
if (context.Node is MethodDeclarationSyntax methodDeclarationSyntax
&& methodDeclarationSyntax.AttributeLists.Count > 0)
if (context.Node is MethodDeclarationSyntax methodDeclarationSyntax)
{
// Get the symbol being declared by the field, and keep it if its annotated
IMethodSymbol? methodSymbol = context.SemanticModel.GetDeclaredSymbol(context.Node) as IMethodSymbol;
Expand All @@ -939,6 +1006,11 @@ public void OnVisitSyntaxNode(GeneratorSyntaxContext context)
{
this.Methods.Add(methodSymbol);
}

if (methodSymbol.ContainingType.GetAttributes().Any(ad => ad.AttributeClass?.ToDisplayString() == "RepositoryAttribute"))
{
this.Methods.Add(methodSymbol);
}
}
}
}
Expand Down

0 comments on commit b8be291

Please sign in to comment.