diff --git a/readme/async-root.md b/readme/async-root.md index 2cdd97795..43a111a15 100644 --- a/readme/async-root.md +++ b/readme/async-root.md @@ -50,13 +50,13 @@ partial class Composition [MethodImpl(MethodImplOptions.AggressiveInlining)] public Task GetMyServiceAsync(CancellationToken cancellationToken) { - TaskScheduler transientTaskScheduler5 = TaskScheduler.Default; - TaskContinuationOptions transientTaskContinuationOptions4 = TaskContinuationOptions.None; - TaskCreationOptions transientTaskCreationOptions3 = TaskCreationOptions.None; TaskFactory perBlockTaskFactory2; CancellationToken localCancellationToken17 = cancellationToken; + TaskCreationOptions transientTaskCreationOptions3 = TaskCreationOptions.None; TaskCreationOptions localTaskCreationOptions18 = transientTaskCreationOptions3; + TaskContinuationOptions transientTaskContinuationOptions4 = TaskContinuationOptions.None; TaskContinuationOptions localTaskContinuationOptions19 = transientTaskContinuationOptions4; + TaskScheduler transientTaskScheduler5 = TaskScheduler.Default; TaskScheduler localTaskScheduler20 = transientTaskScheduler5; perBlockTaskFactory2 = new TaskFactory(localCancellationToken17, localTaskCreationOptions18, localTaskContinuationOptions19, localTaskScheduler20); Func perBlockFunc1 = new Func([MethodImpl(MethodImplOptions.AggressiveInlining)] () => diff --git a/readme/generic-async-composition-roots-with-constraints.md b/readme/generic-async-composition-roots-with-constraints.md index b3a4817dd..4a9edb5a8 100644 --- a/readme/generic-async-composition-roots-with-constraints.md +++ b/readme/generic-async-composition-roots-with-constraints.md @@ -84,13 +84,13 @@ partial class Composition public Task> GetOtherServiceAsync(CancellationToken cancellationToken) where T: IDisposable { - TaskScheduler transientTaskScheduler5 = TaskScheduler.Default; - TaskContinuationOptions transientTaskContinuationOptions4 = TaskContinuationOptions.None; - TaskCreationOptions transientTaskCreationOptions3 = TaskCreationOptions.None; TaskFactory> perBlockTaskFactory2; CancellationToken localCancellationToken47 = cancellationToken; + TaskCreationOptions transientTaskCreationOptions3 = TaskCreationOptions.None; TaskCreationOptions localTaskCreationOptions48 = transientTaskCreationOptions3; + TaskContinuationOptions transientTaskContinuationOptions4 = TaskContinuationOptions.None; TaskContinuationOptions localTaskContinuationOptions49 = transientTaskContinuationOptions4; + TaskScheduler transientTaskScheduler5 = TaskScheduler.Default; TaskScheduler localTaskScheduler50 = transientTaskScheduler5; perBlockTaskFactory2 = new TaskFactory>(localCancellationToken47, localTaskCreationOptions48, localTaskContinuationOptions49, localTaskScheduler50); Func> perBlockFunc1 = new Func>([MethodImpl(MethodImplOptions.AggressiveInlining)] () => @@ -116,13 +116,13 @@ partial class Composition where T: IDisposable where T1: struct { - TaskScheduler transientTaskScheduler5 = TaskScheduler.Default; - TaskContinuationOptions transientTaskContinuationOptions4 = TaskContinuationOptions.None; - TaskCreationOptions transientTaskCreationOptions3 = TaskCreationOptions.None; TaskFactory> perBlockTaskFactory2; CancellationToken localCancellationToken55 = cancellationToken; + TaskCreationOptions transientTaskCreationOptions3 = TaskCreationOptions.None; TaskCreationOptions localTaskCreationOptions56 = transientTaskCreationOptions3; + TaskContinuationOptions transientTaskContinuationOptions4 = TaskContinuationOptions.None; TaskContinuationOptions localTaskContinuationOptions57 = transientTaskContinuationOptions4; + TaskScheduler transientTaskScheduler5 = TaskScheduler.Default; TaskScheduler localTaskScheduler58 = transientTaskScheduler5; perBlockTaskFactory2 = new TaskFactory>(localCancellationToken55, localTaskCreationOptions56, localTaskContinuationOptions57, localTaskScheduler58); Func> perBlockFunc1 = new Func>([MethodImpl(MethodImplOptions.AggressiveInlining)] () => diff --git a/readme/simplified-factory.md b/readme/simplified-factory.md index 1d759c139..6a3450ad4 100644 --- a/readme/simplified-factory.md +++ b/readme/simplified-factory.md @@ -81,9 +81,9 @@ partial class Composition [MethodImpl(MethodImplOptions.AggressiveInlining)] get { - DateTimeOffset transientDateTimeOffset3 = DateTimeOffset.Now; Dependency transientDependency1; Dependency localDependency29 = new Dependency(); + DateTimeOffset transientDateTimeOffset3 = DateTimeOffset.Now; DateTimeOffset localTime30 = transientDateTimeOffset3; localDependency29.Initialize(localTime30); transientDependency1 = localDependency29; diff --git a/readme/task.md b/readme/task.md index b7d0ab24e..a6d2aeaf2 100644 --- a/readme/task.md +++ b/readme/task.md @@ -79,13 +79,13 @@ partial class Composition [MethodImpl(MethodImplOptions.AggressiveInlining)] public IService GetRoot(CancellationToken cancellationToken) { - TaskScheduler transientTaskScheduler6 = TaskScheduler.Current; - TaskContinuationOptions transientTaskContinuationOptions5 = TaskContinuationOptions.None; - TaskCreationOptions transientTaskCreationOptions4 = TaskCreationOptions.None; TaskFactory perBlockTaskFactory3; CancellationToken localCancellationToken39 = cancellationToken; + TaskCreationOptions transientTaskCreationOptions4 = TaskCreationOptions.None; TaskCreationOptions localTaskCreationOptions40 = transientTaskCreationOptions4; + TaskContinuationOptions transientTaskContinuationOptions5 = TaskContinuationOptions.None; TaskContinuationOptions localTaskContinuationOptions41 = transientTaskContinuationOptions5; + TaskScheduler transientTaskScheduler6 = TaskScheduler.Current; TaskScheduler localTaskScheduler42 = transientTaskScheduler6; perBlockTaskFactory3 = new TaskFactory(localCancellationToken39, localTaskCreationOptions40, localTaskContinuationOptions41, localTaskScheduler42); Func perBlockFunc2 = new Func([MethodImpl(MethodImplOptions.AggressiveInlining)] () => diff --git a/src/Pure.DI.Core/Core/Code/BlockCodeBuilder.cs b/src/Pure.DI.Core/Core/Code/BlockCodeBuilder.cs index b59d2d7e4..528d11af4 100644 --- a/src/Pure.DI.Core/Core/Code/BlockCodeBuilder.cs +++ b/src/Pure.DI.Core/Core/Code/BlockCodeBuilder.cs @@ -67,7 +67,11 @@ public void Build(BuildContext ctx, in Block block) var content = new LinesBuilder(); foreach (var statement in block.Statements) { - ctx.StatementBuilder.Build(ctx with { Variable = statement.Current, Code = content }, statement); + var curVar = statement.Current; + if (ctx.IsFactory || curVar.Injection.Kind != InjectionKind.FactoryInjection || curVar.Node.Lifetime != Lifetime.Transient) + { + ctx.StatementBuilder.Build(ctx with { Variable = curVar, Code = content, IsFactory = false }, statement); + } } if (content.Count == 0) diff --git a/src/Pure.DI.Core/Core/Code/BuildContext.cs b/src/Pure.DI.Core/Core/Code/BuildContext.cs index 24e66bd70..d8f0fb915 100644 --- a/src/Pure.DI.Core/Core/Code/BuildContext.cs +++ b/src/Pure.DI.Core/Core/Code/BuildContext.cs @@ -10,4 +10,5 @@ internal record BuildContext( LinesBuilder LocalFunctionsCode, object? ContextTag, bool? LockIsRequired, - ImmutableArray Accumulators); \ No newline at end of file + ImmutableArray Accumulators, + bool IsFactory = false); \ No newline at end of file diff --git a/src/Pure.DI.Core/Core/Code/FactoryCodeBuilder.cs b/src/Pure.DI.Core/Core/Code/FactoryCodeBuilder.cs index fc1fa9283..fb2992574 100644 --- a/src/Pure.DI.Core/Core/Code/FactoryCodeBuilder.cs +++ b/src/Pure.DI.Core/Core/Code/FactoryCodeBuilder.cs @@ -198,8 +198,8 @@ public void Build(BuildContext ctx, in DpFactory factory) lines.AddRange(text.Lines); } - var injectionArgs = variable.Args.Where(i => i.Current.Injection.Kind == InjectionKind.Injection).ToList(); - var initializationArgs = variable.Args.Where(i => i.Current.Injection.Kind != InjectionKind.Injection).ToList(); + var injectionArgs = variable.Args.Where(i => i.Current.Injection.Kind is InjectionKind.FactoryInjection).ToList(); + var initializationArgs = variable.Args.Where(i => i.Current.Injection.Kind != InjectionKind.FactoryInjection).ToList(); // Replaces injection markers by injection code if (injectionArgs.Count != injections.Count) @@ -276,7 +276,7 @@ public void Build(BuildContext ctx, in DpFactory factory) var indent = prefixes.Count; using (code.Indent(indent)) { - ctx.StatementBuilder.Build(injectionsCtx with { Level = level, Variable = argument.Current, LockIsRequired = lockIsRequired }, argument); + ctx.StatementBuilder.Build(injectionsCtx with { Level = level, Variable = argument.Current, LockIsRequired = lockIsRequired, IsFactory = true }, argument); code.AppendLine($"{(resolver.DeclarationRequired ? $"{typeResolver.Resolve(ctx.DependencyGraph.Source, argument.Current.Injection.Type)} " : "")}{resolver.VariableName} = {ctx.BuildTools.OnInjected(ctx, argument.Current)};"); } diff --git a/src/Pure.DI.Core/Core/Code/FactoryRewriter.cs b/src/Pure.DI.Core/Core/Code/FactoryRewriter.cs index eb3022d19..97aea7c89 100644 --- a/src/Pure.DI.Core/Core/Code/FactoryRewriter.cs +++ b/src/Pure.DI.Core/Core/Code/FactoryRewriter.cs @@ -128,76 +128,61 @@ private ExpressionStatementSyntax CreateAssignmentExpression(SyntaxNode returnBo SyntaxFactory.IdentifierName(variable.VariableName).WithLeadingTrivia(SyntaxFactory.Space).WithTrailingTrivia(SyntaxFactory.Space), (ExpressionSyntax)Visit(returnBody).WithLeadingTrivia(SyntaxFactory.Space))), owner); - - public override SyntaxNode? VisitInvocationExpression(InvocationExpressionSyntax invocation) + + public override SyntaxNode VisitExpressionStatement(ExpressionStatementSyntax node) { - if (invocation.ArgumentList.Arguments.Count > 0) - { - if (invocation.Expression is MemberAccessExpressionSyntax - { - Name: GenericNameSyntax - { - Identifier.Text: nameof(IContext.Inject), - TypeArgumentList.Arguments: [not null] - }, - Expression: IdentifierNameSyntax ctx - } - && ctx.Identifier.Text == factory.Source.Context.Identifier.Text - && TryInject(invocation, out var visitInvocationExpression)) + node = (ExpressionStatementSyntax)base.VisitExpressionStatement(node)!; + if (node.Expression is not InvocationExpressionSyntax { - return visitInvocationExpression; - } + ArgumentList.Arguments.Count: > 0, + Expression: MemberAccessExpressionSyntax { Expression: IdentifierNameSyntax ctx } memberAccessExpression + } invocation + || ctx.Identifier.Text != factory.Source.Context.Identifier.Text) + { + return node; + } - if (invocation.Expression is MemberAccessExpressionSyntax - { - Name: IdentifierNameSyntax - { - Identifier.Text: nameof(IContext.Inject) - }, - Expression: IdentifierNameSyntax ctx2 - } - && ctx2.Identifier.Text == factory.Source.Context.Identifier.Text - && TryInject(invocation, out visitInvocationExpression)) - { - return visitInvocationExpression; - } + var name = ""; + switch (memberAccessExpression.Name) + { + case GenericNameSyntax { TypeArgumentList.Arguments.Count: 1 } genericName: + name = genericName.Identifier.Text; + break; - if (invocation.Expression is MemberAccessExpressionSyntax - { - Name: GenericNameSyntax - { - Identifier.Text: nameof(IContext.BuildUp), - TypeArgumentList.Arguments: [not null] - }, - Expression: IdentifierNameSyntax ctx3 - } - && ctx3.Identifier.Text == factory.Source.Context.Identifier.Text - && TryInitialize(invocation, out visitInvocationExpression)) - { - return visitInvocationExpression; - } + case IdentifierNameSyntax identifierName: + name = identifierName.Identifier.Text; + break; + } - if (invocation.Expression is MemberAccessExpressionSyntax - { - Name: IdentifierNameSyntax - { - Identifier.Text: nameof(IContext.BuildUp) - }, - Expression: IdentifierNameSyntax ctx4 - } - && ctx4.Identifier.Text == factory.Source.Context.Identifier.Text - && TryInitialize(invocation, out visitInvocationExpression)) - { - return visitInvocationExpression; - } + ExpressionSyntax? expressionSyntax = default; + var processed = name switch + { + nameof(IContext.Inject) => TryInject(invocation, out expressionSyntax), + nameof(IContext.BuildUp) => TryInitialize(invocation, out expressionSyntax), + _ => false + }; + + if (!processed || expressionSyntax is null) + { + return node; + } + + SyntaxNode newNode; + if (node.Parent is null or BlockSyntax) + { + newNode = SyntaxFactory.ExpressionStatement(expressionSyntax).WithLeadingTrivia(SyntaxFactory.LineFeed).WithTrailingTrivia(SyntaxFactory.LineFeed); + } + else + { + newNode = SyntaxFactory.Block().AddStatements(SyntaxFactory.ExpressionStatement(expressionSyntax).WithLeadingTrivia(SyntaxFactory.LineFeed).WithTrailingTrivia(SyntaxFactory.LineFeed)); } - return base.VisitInvocationExpression(invocation); + return triviaTools.PreserveTrivia(newNode, node); } private bool TryInject( InvocationExpressionSyntax invocation, - [NotNullWhen(true)] out SyntaxNode? visitInvocationExpression) + [NotNullWhen(true)] out ExpressionSyntax? expressionSyntax) { var value = invocation.ArgumentList.Arguments.Count switch { @@ -211,25 +196,25 @@ private bool TryInject( case IdentifierNameSyntax identifierName: injections.Add(new Injection(identifierName.Identifier.Text, false)); { - visitInvocationExpression = triviaTools.PreserveTrivia(InjectionMarkerExpression, invocation); + expressionSyntax = triviaTools.PreserveTrivia(InjectionMarkerExpression, invocation); return true; } case DeclarationExpressionSyntax { Designation: SingleVariableDesignationSyntax singleVariableDesignationSyntax }: injections.Add(new Injection(singleVariableDesignationSyntax.Identifier.Text, true)); { - visitInvocationExpression = triviaTools.PreserveTrivia(InjectionMarkerExpression, invocation); + expressionSyntax = triviaTools.PreserveTrivia(InjectionMarkerExpression, invocation); return true; } } - visitInvocationExpression = default; + expressionSyntax = default; return false; } private bool TryInitialize( InvocationExpressionSyntax invocation, - [NotNullWhen(true)] out SyntaxNode? visitInvocationExpression) + [NotNullWhen(true)] out ExpressionSyntax? expressionSyntax) { var value = invocation.ArgumentList.Arguments.Count switch { @@ -242,33 +227,22 @@ private bool TryInitialize( case IdentifierNameSyntax identifierName: initializers.Add(new Initializer(identifierName.Identifier.Text, false)); { - visitInvocationExpression = triviaTools.PreserveTrivia(InitializationMarkerExpression, invocation); + expressionSyntax = triviaTools.PreserveTrivia(InitializationMarkerExpression, invocation); return true; } case DeclarationExpressionSyntax { Designation: SingleVariableDesignationSyntax singleVariableDesignationSyntax }: initializers.Add(new Initializer(singleVariableDesignationSyntax.Identifier.Text, true)); { - visitInvocationExpression = triviaTools.PreserveTrivia(InitializationMarkerExpression, invocation); + expressionSyntax = triviaTools.PreserveTrivia(InitializationMarkerExpression, invocation); return true; } } - visitInvocationExpression = default; + expressionSyntax = default; return false; } - public override SyntaxNode VisitExpressionStatement(ExpressionStatementSyntax node) - { - var newNode = (ExpressionStatementSyntax)base.VisitExpressionStatement(node)!; - if (newNode.Expression.IsEquivalentTo(InjectionMarkerExpression)) - { - return triviaTools.PreserveTrivia(newNode, node); - } - - return newNode; - } - public override SyntaxNode? VisitMemberAccessExpression(MemberAccessExpressionSyntax node) { if (node.IsKind(SyntaxKind.SimpleMemberAccessExpression) diff --git a/src/Pure.DI.Core/Core/FactoryDependencyNodeBuilder.cs b/src/Pure.DI.Core/Core/FactoryDependencyNodeBuilder.cs index 299669fd7..e98d0091f 100644 --- a/src/Pure.DI.Core/Core/FactoryDependencyNodeBuilder.cs +++ b/src/Pure.DI.Core/Core/FactoryDependencyNodeBuilder.cs @@ -19,7 +19,7 @@ public IEnumerable Build(MdSetup setup) foreach (var resolver in factory.Resolvers) { var tag = attributes.GetAttribute(resolver.SemanticModel, setup.TagAttributes, resolver.Attributes, default(object?)) ?? resolver.Tag?.Value; - injections.Add(new Injection(InjectionKind.Injection, resolver.ContractType.WithNullableAnnotation(NullableAnnotation.NotAnnotated), tag)); + injections.Add(new Injection(InjectionKind.FactoryInjection, resolver.ContractType.WithNullableAnnotation(NullableAnnotation.NotAnnotated), tag)); } var compilation = binding.SemanticModel.Compilation; diff --git a/src/Pure.DI.Core/Core/Models/InjectionKind.cs b/src/Pure.DI.Core/Core/Models/InjectionKind.cs index 2bb1a2baa..2f802d2f4 100644 --- a/src/Pure.DI.Core/Core/Models/InjectionKind.cs +++ b/src/Pure.DI.Core/Core/Models/InjectionKind.cs @@ -3,10 +3,16 @@ internal enum InjectionKind { Field, + Property, + Parameter, + Root, - Injection, + + FactoryInjection, + Contract, + Construct } \ No newline at end of file diff --git a/tests/Pure.DI.IntegrationTests/FactoryTests.cs b/tests/Pure.DI.IntegrationTests/FactoryTests.cs index ffc864bc1..dc7ba39de 100644 --- a/tests/Pure.DI.IntegrationTests/FactoryTests.cs +++ b/tests/Pure.DI.IntegrationTests/FactoryTests.cs @@ -453,7 +453,12 @@ namespace Sample { interface IDependency {} - class Dependency: IDependency {} + class Dependency: IDependency + { + public Dependency(int id, string str) + { + } + } interface IService { @@ -475,14 +480,15 @@ static class Setup private static void SetupComposition() { DI.Setup("Composition") - .Bind().To(ctx => new Dependency()) + .Bind().To(_ => "Abc") + .Bind().To(_ => 33) + .Bind(22).To() + .Bind().To(ctx => new Dependency(0, "xyz")) .Bind().To(ctx => { IDependency dependency1; var rnd = new Random(1).Next(3); if (rnd == 0) - { - ctx.Inject(out dependency1); - } + ctx.Inject(22, out dependency1); else { if (rnd == 1) @@ -1591,6 +1597,82 @@ public static void Main() result.StdOut.ShouldBe(["Initialize Abc", "Id: 33"], result); } + [Fact] + public async Task ShouldSupportInitializationWhenMethodAndPropertyInIf() + { + // Given + + // When + var result = await """ + using System; + using Pure.DI; + + namespace Sample + { + class Dependency + { + [Ordinal(1)] + internal void Initialize([Tag(374)] string depName) + { + Console.WriteLine($"Initialize {depName}"); + } + + [Ordinal(2)] + public int Id { get; set; } + } + + interface IService + { + Dependency Dep { get; } + } + + class Service: IService + { + public Service(Dependency dep) + { + Dep = dep; + } + + public Dependency Dep { get; } + } + + static class Setup + { + private static void SetupComposition() + { + DI.Setup("Composition") + .Bind(374).To(_ => "Abc") + .Bind().To(_ => 33) + .Bind().To(ctx => { + var dep = new Dependency(); + if (true) + { + ctx.BuildUp(dep); + } + return dep; + }) + .Bind().To() + .Root("Service"); + } + } + + public class Program + { + public static void Main() + { + var composition = new Composition(); + var service = composition.Service; + Console.WriteLine($"Id: {service.Dep.Id}"); + } + } + } + """.RunAsync(); + + // Then + result.Success.ShouldBeTrue(result); + result.StdOut.ShouldBe(["Initialize Abc", "Id: 33"], result); + } + [Fact] public async Task ShouldSupportInitializationWhenMethodAndField() {