Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions sdk/Sdk.Generators/Extensions/IMethodSymbolExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the MIT License. See License.txt in the project root for license information.

using Microsoft.CodeAnalysis;

namespace Microsoft.Azure.Functions.Worker.Sdk.Generators
{
internal static class IMethodSymbolExtensions
{
/// <summary>
/// Determines the visibility of an azure function method.
/// The visibility is determined by the following rules:
/// 1. If the method is public, and all containing types are public, return Public
/// 2. If the method is public, but one or more containing types are not public, return PublicButContainingTypeNotVisible
/// 3. If the method is not public, return NotPublic
/// </summary>
/// <param name="methodSymbol">The <see cref="IMethodSymbol"/> instance representing an azure function method.</param>
/// <returns><see cref="FunctionMethodVisibility"/></returns>
internal static FunctionMethodVisibility GetVisibility(this IMethodSymbol methodSymbol)
{
// Check if the symbol itself is public
if (methodSymbol.DeclaredAccessibility == Accessibility.Public)
{
// Check if any containing type is not public
INamedTypeSymbol containingType = methodSymbol.ContainingType;
while (containingType != null)
{
if (containingType.DeclaredAccessibility != Accessibility.Public)
{
return FunctionMethodVisibility.PublicButContainingTypeNotVisible;
}
containingType = containingType.ContainingType;
}

// If both the symbol and all containing types are public, return PublicAndVisible
return FunctionMethodVisibility.Public;
}

// If the symbol itself is not public, return NotPublic
return FunctionMethodVisibility.NotPublic;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@ public partial class FunctionExecutorGenerator
{
internal static class Emitter
{
internal static string Emit(GeneratorExecutionContext context, IEnumerable<ExecutableFunction> functions, bool includeAutoRegistrationCode)
private const string WorkerCoreAssemblyName = "Microsoft.Azure.Functions.Worker.Core";

internal static string Emit(GeneratorExecutionContext context, IEnumerable<ExecutableFunction> executableFunctions, bool includeAutoRegistrationCode)
{
var functions = executableFunctions.ToList();
var defaultExecutorNeeded = functions.Any(f => f.Visibility == FunctionMethodVisibility.PublicButContainingTypeNotVisible);

string result = $$"""
// <auto-generated/>
Expand All @@ -31,7 +35,7 @@ namespace {{FunctionsUtil.GetNamespaceForGeneratedCode(context)}}
[global::System.ComponentModel.EditorBrowsableAttribute(global::System.ComponentModel.EditorBrowsableState.Never)]
internal class DirectFunctionExecutor : IFunctionExecutor
{
private readonly IFunctionActivator _functionActivator;
private readonly IFunctionActivator _functionActivator;{{(defaultExecutorNeeded ? $"{Environment.NewLine} private Lazy<IFunctionExecutor> _defaultExecutor;" : string.Empty)}}
{{GetTypesDictionary(functions)}}
public DirectFunctionExecutor(IFunctionActivator functionActivator)
{
Expand All @@ -41,8 +45,8 @@ public DirectFunctionExecutor(IFunctionActivator functionActivator)
/// <inheritdoc/>
public async ValueTask ExecuteAsync(FunctionContext context)
{
{{GetMethodBody(functions)}}
}
{{GetMethodBody(functions, defaultExecutorNeeded)}}
}{{(defaultExecutorNeeded ? $"{Environment.NewLine}{EmitCreateDefaultExecutorMethod(context)}" : string.Empty)}}
}

/// <summary>
Expand All @@ -67,20 +71,40 @@ public static IHostBuilder ConfigureGeneratedFunctionExecutor(this IHostBuilder
return result;
}

private static string EmitCreateDefaultExecutorMethod(GeneratorExecutionContext context)
{
var workerCoreAssembly = context.Compilation.SourceModule.ReferencedAssemblySymbols.Single(a => a.Name == WorkerCoreAssemblyName);
var assemblyIdentity = workerCoreAssembly.Identity;

return $$"""

private IFunctionExecutor CreateDefaultExecutorInstance(FunctionContext context)
{
var defaultExecutorFullName = "Microsoft.Azure.Functions.Worker.Invocation.DefaultFunctionExecutor, {{assemblyIdentity}}";
var defaultExecutorType = Type.GetType(defaultExecutorFullName);

return ActivatorUtilities.CreateInstance(context.InstanceServices, defaultExecutorType) as IFunctionExecutor;
}
""";
}

private static string GetTypesDictionary(IEnumerable<ExecutableFunction> functions)
{
var classNames = functions.Where(f => !f.IsStatic).Select(f => f.ParentFunctionClassName).Distinct();
if (!classNames.Any())
{
return """
// Build a dictionary of type names and its full qualified names (including assembly identity)
var typesDict = functions
.Where(f => !f.IsStatic)
.GroupBy(f => f.ParentFunctionClassName)
.ToDictionary(k => k.First().ParentFunctionClassName, v => v.First().AssemblyIdentity);

""";
if (typesDict.Count == 0)
{
return "";
}

return $$"""
private readonly Dictionary<string, Type> types = new()
{
{{string.Join($",{Environment.NewLine} ", classNames.Select(c => $$""" { "{{c}}", Type.GetType("{{c}}")! }"""))}}
{{string.Join($",{Environment.NewLine} ", typesDict.Select(c => $$""" { "{{c.Key}}", Type.GetType("{{c.Key}}, {{c.Value}}")! }"""))}}
};

""";
Expand Down Expand Up @@ -114,64 +138,84 @@ public void Configure(IHostBuilder hostBuilder)
return "";
}

private static string GetMethodBody(IEnumerable<ExecutableFunction> functions)
private static string GetMethodBody(IEnumerable<ExecutableFunction> functions, bool anyDefaultExecutor)
{
var sb = new StringBuilder();
sb.Append(
"""
$$"""
var inputBindingFeature = context.Features.Get<IFunctionInputBindingFeature>()!;
var inputBindingResult = await inputBindingFeature.BindFunctionInputAsync(context)!;
var inputArguments = inputBindingResult.Values;

{{(anyDefaultExecutor ? $" _defaultExecutor = new Lazy<IFunctionExecutor>(() => CreateDefaultExecutorInstance(context));{Environment.NewLine}" : string.Empty)}}
""");

bool first = true;

foreach (ExecutableFunction function in functions)
{
var fast = function.Visibility == FunctionMethodVisibility.Public;
sb.Append($$"""

{{(first ? string.Empty : "else ")}}if (string.Equals(context.FunctionDefinition.EntryPoint, "{{function.EntryPoint}}", StringComparison.Ordinal))
{
{{(fast ? EmitFastPath(function) : EmitSlowPath())}}
}
""");

first = false;
int functionParamCounter = 0;
var functionParamList = new List<string>();
foreach (var argumentTypeName in function.ParameterTypeNames)
{
functionParamList.Add($"({argumentTypeName})inputArguments[{functionParamCounter++}]");
}
var methodParamsStr = string.Join(", ", functionParamList);

if (!function.IsStatic)
{
sb.Append($$"""

var instanceType = types["{{function.ParentFunctionClassName}}"];
}

return sb.ToString();
}

private static string EmitFastPath(ExecutableFunction function)
{
var sb = new StringBuilder();
int functionParamCounter = 0;
var functionParamList = new List<string>();
foreach (var argumentTypeName in function.ParameterTypeNames)
{
functionParamList.Add($"({argumentTypeName})inputArguments[{functionParamCounter++}]");
}
var methodParamsStr = string.Join(", ", functionParamList);

if (!function.IsStatic)
{
sb.Append($$"""
var instanceType = types["{{function.ParentFunctionClassName}}"];
var i = _functionActivator.CreateInstance(instanceType, context) as {{function.ParentFunctionFullyQualifiedClassName}};
""");
}
}

if (!function.IsStatic)
{
sb.Append(@"
");
}
else
{
sb.Append(" ");
}

if (function.IsReturnValueAssignable)
{
sb.Append(@$"context.GetInvocationResult().Value = ");
}
if (function.ShouldAwait)
{
sb.Append("await ");
}

sb.Append(function.IsStatic
? @$"{function.ParentFunctionFullyQualifiedClassName}.{function.MethodName}({methodParamsStr});
}}"
: $@"i.{function.MethodName}({methodParamsStr});
}}");
if (function.IsReturnValueAssignable)
{
sb.Append("context.GetInvocationResult().Value = ");
}
if (function.ShouldAwait)
{
sb.Append("await ");
}

sb.Append(function.IsStatic
? $"{function.ParentFunctionFullyQualifiedClassName}.{function.MethodName}({methodParamsStr});"
: $"i.{function.MethodName}({methodParamsStr});");
return sb.ToString();
}

private static string EmitSlowPath()
{
return
" await _defaultExecutor.Value.ExecuteAsync(context);";
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,16 @@ internal class ExecutableFunction
/// A collection of fully qualified type names of the parameters of the function.
/// </summary>
internal IEnumerable<string> ParameterTypeNames { set; get; } = Enumerable.Empty<string>();

/// <summary>
/// Get a value indicating the visibility of the executable function.
/// </summary>
internal FunctionMethodVisibility Visibility { get; set; }

/// <summary>
/// Gets the assembly identity of the function.
/// ex: FooAssembly, Version=1.2.3.4, Culture=neutral, PublicKeyToken=9475d07f10cb09df
/// </summary>
internal string AssemblyIdentity { get; set; } = null!;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
// Licensed under the MIT License. See License.txt in the project root for license information.

using System.Collections.Generic;
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;

namespace Microsoft.Azure.Functions.Worker.Sdk.Generators
{
Expand All @@ -23,48 +22,38 @@ internal Parser(GeneratorExecutionContext context)

private Compilation Compilation => _context.Compilation;

internal ICollection<ExecutableFunction> GetFunctions(List<MethodDeclarationSyntax> methods)
internal ICollection<ExecutableFunction> GetFunctions(IEnumerable<IMethodSymbol> methods)
{
var functionList = new List<ExecutableFunction>();

foreach (MethodDeclarationSyntax method in methods)
foreach (IMethodSymbol method in methods.Where(m=>m.DeclaredAccessibility == Accessibility.Public))
{
_context.CancellationToken.ThrowIfCancellationRequested();
var model = Compilation.GetSemanticModel(method.SyntaxTree);

if (!FunctionsUtil.IsValidFunctionMethod(_context, Compilation, model, method))
{
continue;
}
var methodName = method.Name;
var methodParameterList = new List<string>();

var methodName = method.Identifier.Text;
var methodParameterList = new List<string>(method.ParameterList.Parameters.Count);

foreach (var methodParam in method.ParameterList.Parameters)
foreach (IParameterSymbol parameterSymbol in method.Parameters)
{
if (model.GetDeclaredSymbol(methodParam) is not IParameterSymbol parameterSymbol)
{
continue;
}

var fullyQualifiedTypeName = parameterSymbol.Type.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
methodParameterList.Add(fullyQualifiedTypeName);
}

var methodSymbol = model.GetDeclaredSymbol(method)!;
var defaultFormatClassName = methodSymbol.ContainingSymbol.ToDisplayString();
var fullyQualifiedClassName = methodSymbol.ContainingSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);
var defaultFormatClassName = method.ContainingSymbol.ToDisplayString();
var fullyQualifiedClassName = method.ContainingSymbol.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat);

var function = new ExecutableFunction
{
EntryPoint = $"{defaultFormatClassName}.{method.Identifier.ValueText}",
EntryPoint = $"{defaultFormatClassName}.{method.Name}",
ParameterTypeNames = methodParameterList,
MethodName = methodName,
ShouldAwait = IsTaskType(methodSymbol.ReturnType),
IsReturnValueAssignable = IsReturnValueAssignable(methodSymbol),
IsStatic = method.Modifiers.Any(SyntaxKind.StaticKeyword),
ShouldAwait = IsTaskType(method.ReturnType),
IsReturnValueAssignable = IsReturnValueAssignable(method),
IsStatic = method.IsStatic,
ParentFunctionClassName = defaultFormatClassName,
ParentFunctionFullyQualifiedClassName = fullyQualifiedClassName
ParentFunctionFullyQualifiedClassName = fullyQualifiedClassName,
Visibility = method.GetVisibility(),
AssemblyIdentity = method.ContainingAssembly.Identity.GetDisplayName(),
};

functionList.Add(function);
Expand Down
Loading