diff --git a/sdk/Sdk.Generators/FunctionMetadataProviderGenerator/FunctionMetadataProviderGenerator.cs b/sdk/Sdk.Generators/FunctionMetadataProviderGenerator/FunctionMetadataProviderGenerator.cs index f19102e77..7bcafdbf8 100644 --- a/sdk/Sdk.Generators/FunctionMetadataProviderGenerator/FunctionMetadataProviderGenerator.cs +++ b/sdk/Sdk.Generators/FunctionMetadataProviderGenerator/FunctionMetadataProviderGenerator.cs @@ -93,55 +93,10 @@ private IEnumerable GetEntryAssemblyFunctions(List private IEnumerable GetDependentAssemblyFunctions(GeneratorExecutionContext context) { - foreach (var assembly in context.Compilation.SourceModule.ReferencedAssemblySymbols) - { - foreach (var methodSymbol in GetMethodsFromNamespace(context.Compilation, assembly.GlobalNamespace)) - { - yield return methodSymbol; - } - } - } + var visitor = new ReferencedAssemblyMethodVisitor(context.Compilation); + visitor.Visit(context.Compilation.SourceModule); - private IEnumerable GetMethodsFromNamespace(Compilation compilation, INamespaceSymbol namespaceSymbol) - { - foreach (var member in namespaceSymbol.GetMembers()) - { - if (member is INamespaceSymbol nestedNamespace) - { - // Recursive call for nested namespaces - foreach (var methodSymbol in GetMethodsFromNamespace(compilation, nestedNamespace)) - { - yield return methodSymbol; - } - } - else if (member is INamedTypeSymbol typeSymbol) - { - // Recursive call for nested types - foreach (var methodSymbol in GetMethodsFromType(compilation, typeSymbol)) - { - yield return methodSymbol; - } - } - } - } - - private IEnumerable GetMethodsFromType(Compilation compilation, INamedTypeSymbol typeSymbol) - { - foreach (var member in typeSymbol.GetMembers()) - { - if (member is IMethodSymbol methodSymbol && FunctionsUtil.IsFunctionSymbol(methodSymbol, compilation)) - { - yield return methodSymbol; - } - else if (member is INamedTypeSymbol nestedType) - { - // Recursive call for nested types - foreach (var nestedMethodSymbol in GetMethodsFromType(compilation, nestedType)) - { - yield return nestedMethodSymbol; - } - } - } + return visitor.FunctionMethods; } } } diff --git a/sdk/Sdk.Generators/FunctionMetadataProviderGenerator/ReferencedAssemblyMethodVisitor.cs b/sdk/Sdk.Generators/FunctionMetadataProviderGenerator/ReferencedAssemblyMethodVisitor.cs new file mode 100644 index 000000000..97c288c1e --- /dev/null +++ b/sdk/Sdk.Generators/FunctionMetadataProviderGenerator/ReferencedAssemblyMethodVisitor.cs @@ -0,0 +1,75 @@ +// Copyright (c) .NET Foundation. All rights reserved. +// Licensed under the MIT License. See License.txt in the project root for license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using Microsoft.CodeAnalysis; + +namespace Microsoft.Azure.Functions.Worker.Sdk.Generators +{ + /// + /// Visits all symbols from referenced assemblies and returns all methods which are valid Azure Functions. + /// + internal sealed class ReferencedAssemblyMethodVisitor : SymbolVisitor + { + private readonly Compilation _compilation; + + /// + /// Gets all methods which are valid Azure Functions. + /// + internal readonly List FunctionMethods = new(); + + internal ReferencedAssemblyMethodVisitor(Compilation compilation) + { + _compilation = compilation ?? throw new ArgumentNullException(nameof(compilation)); + } + + public override void VisitModule(IModuleSymbol moduleSymbol) + { + foreach (var assemblySymbol in moduleSymbol.ReferencedAssemblySymbols) + { + assemblySymbol.Accept(this); + } + } + + public override void VisitAssembly(IAssemblySymbol symbol) + { + var namespaceSymbol = symbol.GlobalNamespace; + namespaceSymbol.Accept(this); + } + + public override void VisitNamespace(INamespaceSymbol symbol) + { + // Get classes in this namespace or child namespaces + var classesOrNamespaces = symbol.GetMembers() + .Where(a => a.Kind is SymbolKind.Namespace or SymbolKind.NamedType); + + foreach (var childSymbol in classesOrNamespaces) + { + childSymbol.Accept(this); + } + } + + public override void VisitNamedType(INamedTypeSymbol symbol) + { + // Get methods in this class or nested child classes + var methodsOrClasses = symbol.GetMembers() + .Where(a => a.Kind is SymbolKind.NamedType or SymbolKind.Method); + + foreach (var childSymbol in methodsOrClasses) + { + childSymbol.Accept(this); + } + } + + public override void VisitMethod(IMethodSymbol methodSymbol) + { + if (methodSymbol.MethodKind == MethodKind.Ordinary && + FunctionsUtil.IsFunctionSymbol(methodSymbol, _compilation)) + { + FunctionMethods.Add(methodSymbol); + } + } + } +}