Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compile C# expressions on demand #315

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions src/Test/TestCases.Workflows/ExpressionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,25 @@ public void Vb_CompareLambdas()
validationResults.Errors[0].Message.ShouldContain("A null propagating operator cannot be converted into an expression tree.");
}

[Fact]
public void Cs_CompareLambdas()
{
CSharpValue<string> csv = new(@$"string.Concat(""alpha "", b?.Substring(0, 10), ""beta "", 1)");
WriteLine writeLine = new();
writeLine.Text = new InArgument<string>(csv);
Sequence workflow = new();
workflow.Activities.Add(writeLine);
workflow.Variables.Add(new Variable<string>("b", "I'm a variable"));

ValidationResults validationResults = ActivityValidationServices.Validate(workflow, _forceCache);
validationResults.Errors.Count.ShouldBe(1, string.Join("\n", validationResults.Errors.Select(e => e.Message)));
validationResults.Errors[0].Message.ShouldContain("An expression tree lambda may not contain a null propagating operator.");

validationResults = ActivityValidationServices.Validate(workflow, _useValidator);
validationResults.Errors.Count.ShouldBe(1, string.Join("\n", validationResults.Errors.Select(e => e.Message)));
validationResults.Errors[0].Message.ShouldContain("An expression tree lambda may not contain a null propagating operator.");
}

[Fact]
public void Vb_LambdaExtension()
{
Expand All @@ -172,6 +191,20 @@ public void Vb_LambdaExtension()
validationResults.Errors.Count.ShouldBe(0, string.Join("\n", validationResults.Errors.Select(e => e.Message)));
}

[Fact]
public void Cs_LambdaExtension()
{
CSharpValue<string> csv = new("list.First()");
WriteLine writeLine = new();
writeLine.Text = new InArgument<string>(csv);
Sequence workflow = new();
workflow.Activities.Add(writeLine);
workflow.Variables.Add(new Variable<List<string>>("list"));

ValidationResults validationResults = ActivityValidationServices.Validate(workflow, _useValidator);
validationResults.Errors.Count.ShouldBe(0, string.Join("\n", validationResults.Errors.Select(e => e.Message)));
}

[Fact]
public void Vb_Dictionary()
{
Expand All @@ -185,6 +218,20 @@ public void Vb_Dictionary()
ValidationResults validationResults = ActivityValidationServices.Validate(workflow, _useValidator);
validationResults.Errors.Count.ShouldBe(0, string.Join("\n", validationResults.Errors.Select(e => e.Message)));
}

[Fact]
public void Cs_Dictionary()
{
CSharpValue<string> csv = new("something.FooDictionary[\"key\"].ToString()");
WriteLine writeLine = new();
writeLine.Text = new InArgument<string>(csv);
Sequence workflow = new();
workflow.Activities.Add(writeLine);
workflow.Variables.Add(new Variable<ClassWithCollectionProperties>("something"));

ValidationResults validationResults = ActivityValidationServices.Validate(workflow, _useValidator);
validationResults.Errors.Count.ShouldBe(0, string.Join("\n", validationResults.Errors.Select(e => e.Message)));
}
#region Check locations are not readonly
[Fact]
public void VB_Readonly_ThrowsError()
Expand Down Expand Up @@ -291,6 +338,20 @@ public void Vb_IntOverflow()
validationResults.Errors[0].Message.ShouldContain("Constant expression not representable in type 'Integer'");
}

[Fact]
public void Cs_IntOverflow()
{
VisualBasicValue<int> csv = new("2147483648");
Sequence workflow = new();
workflow.Variables.Add(new Variable<int>("someint"));
Assign assign = new() { To = new OutArgument<int>(workflow.Variables[0]), Value = new InArgument<int>(csv) };
workflow.Activities.Add(assign);

ValidationResults validationResults = ActivityValidationServices.Validate(workflow, _useValidator);
validationResults.Errors.Count.ShouldBe(1, string.Join("\n", validationResults.Errors.Select(e => e.Message)));
validationResults.Errors[0].Message.ShouldContain("Constant expression not representable in type 'Integer'");
}

[Fact]
public void VBValidator_StrictOn()
{
Expand Down
19 changes: 18 additions & 1 deletion src/Test/TestCases.Workflows/WF4Samples/Expressions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ public class AheadOfTimeExpressions : ExpressionsBase
static readonly string CSharpCalculationResult = "Result == XX^2" + Environment.NewLine;
static readonly StringDictionary CSharpCalculationInputs = new() { ["XX"] = 16, ["YY"] = 16 };
[Fact]
public void SameTextDifferentTypes()
public void VBSameTextDifferentTypes()
{
var text = new VisualBasicValue<string>("var");
var values = new VisualBasicValue<IEnumerable<char>>("var");
Expand All @@ -101,6 +101,23 @@ public void SameTextDifferentTypes()
((LambdaExpression)values.GetExpressionTree()).ReturnType.ShouldBe(typeof(IEnumerable<char>));
}
[Fact]
public void CSSameTextDifferentTypes()
{
var text = new CSharpValue<string>("var");
var values = new CSharpValue<IEnumerable<char>>("var");
var root = new DynamicActivity
{
Implementation = () => new Sequence
{
Variables = { new Variable<string>("var") },
Activities = { new ForEach<char> { Values = new InArgument<IEnumerable<char>>(values) }, new WriteLine { Text = new InArgument<string>(text) } }
}
};
ActivityXamlServices.Compile(root, new());
((LambdaExpression)text.GetExpressionTree()).ReturnType.ShouldBe(typeof(string));
((LambdaExpression)values.GetExpressionTree()).ReturnType.ShouldBe(typeof(IEnumerable<char>));
}
[Fact]
public void CompileCSharpCalculation()
{
var activity = Compile(TestXamls.CSharpCalculation);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
using System.Activities.ExpressionParser;
using System.Activities.Expressions;
using System.Collections.Generic;
using System.Linq.Expressions;
using System.Reflection;
using System.Threading.Tasks;
using Microsoft.VisualBasic.Activities;
Expand All @@ -22,7 +23,38 @@ protected override JustInTimeCompiler CreateCompiler(HashSet<Assembly> reference
public CSharpHelper(string expressionText, HashSet<AssemblyReference> assemblyReferences,
HashSet<string> namespaceImportsNames) : base(expressionText, assemblyReferences, namespaceImportsNames) { }

private CSharpHelper(string expressionText) : base(expressionText) { }

internal const string Language = "C#";

public static Expression<Func<ActivityContext, T>> Compile<T>(string expressionText, CodeActivityPublicEnvironmentAccessor publicAccessor, bool isLocationExpression)
{
GetAllImportReferences(publicAccessor.ActivityMetadata.CurrentActivity, false,
out var localNamespaces,
out var localAssemblies);

var helper = new CSharpHelper(expressionText);
var localReferenceAssemblies = new HashSet<AssemblyReference>();
var localImports = new HashSet<string>(localNamespaces);
foreach (var assemblyReference in localAssemblies)
{
if (assemblyReference.Assembly != null)
{
// directly add the Assembly to the list
// so that we don't have to go through
// the assembly resolution process
helper.ReferencedAssemblies ??= new HashSet<Assembly>();
helper.ReferencedAssemblies.Add(assemblyReference.Assembly);
}
else if (assemblyReference.AssemblyName != null)
{
localReferenceAssemblies.Add(assemblyReference);
}
}

helper.Initialize(localReferenceAssemblies, localImports);
return helper.Compile<T>(publicAccessor, isLocationExpression);
}
}

internal class CSharpExpressionFactory<T> : ExpressionFactory
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@

using System;
using System.Activities;
using System.Activities.ExpressionParser;
using System.Activities.Expressions;
using System.Activities.Internals;
using System.Activities.Runtime;
using System.Activities.Validation;
using System.ComponentModel;
using System.Diagnostics;
using System.Linq.Expressions;
using System.Windows.Markup;
using ActivityContext = System.Activities.ActivityContext;

namespace Microsoft.CSharp.Activities;

Expand All @@ -18,6 +21,8 @@ namespace Microsoft.CSharp.Activities;
public class CSharpReference<TResult> : TextExpressionBase<Location<TResult>>, ITextExpression
{
private CompiledExpressionInvoker _invoker;
private LocationFactory<TResult> _locationFactory;
private Expression<Func<ActivityContext, TResult>> _expressionTree;

public CSharpReference() => UseOldFastPath = true;

Expand All @@ -28,13 +33,92 @@ public class CSharpReference<TResult> : TextExpressionBase<Location<TResult>>, I
[DesignerSerializationVisibility(DesignerSerializationVisibility.Hidden)]
public override string Language => CSharpHelper.Language;

public override Expression GetExpressionTree() => IsMetadataCached ? _invoker.GetExpressionTree() : throw FxTrace.Exception.AsError(new InvalidOperationException(SR.ActivityIsUncached));
public override Expression GetExpressionTree()
{
if (IsMetadataCached)
{
if (_expressionTree == null)
{
if (_invoker != null)
{
return _invoker.GetExpressionTree();
}
// it's safe to create this CodeActivityMetadata here,
// because we know we are using it only as lookup purpose.
var metadata = new CodeActivityMetadata(this, GetParentEnvironment(), false);
var publicAccessor = CodeActivityPublicEnvironmentAccessor.CreateWithoutArgument(metadata);
try
{
_expressionTree = CompileLocationExpression(publicAccessor, out var validationError);
if (validationError != null)
{
throw FxTrace.Exception.AsError(new InvalidOperationException(SR.ExpressionTamperedSinceLastCompiled(validationError)));
}
}
finally
{
metadata.Dispose();
}
}
Fx.Assert(_expressionTree.NodeType == ExpressionType.Lambda, "Lambda expression required");
return ExpressionUtilities.RewriteNonCompiledExpressionTree(_expressionTree);
}
throw FxTrace.Exception.AsError(new InvalidOperationException(SR.ActivityIsUncached));
}

protected override Location<TResult> Execute(CodeActivityContext context)
{
if (_expressionTree == null)
{
return (Location<TResult>)_invoker.InvokeExpression(context);
}
_locationFactory ??= ExpressionUtilities.CreateLocationFactory<TResult>(_expressionTree);
return _locationFactory.CreateLocation(context);
}

protected override void CacheMetadata(CodeActivityMetadata metadata)
{
_expressionTree = null;
_invoker = new CompiledExpressionInvoker(this, true, metadata);
QueueForValidation<TResult>(metadata, true);

if (QueueForValidation<TResult>(metadata, true))
{
return;
}
// If ICER is not implemented that means we haven't been compiled
var publicAccessor = CodeActivityPublicEnvironmentAccessor.Create(metadata);
_expressionTree = CompileLocationExpression(publicAccessor, out var validationError);
if (validationError != null)
{
metadata.AddValidationError(validationError);
}
}

protected override Location<TResult> Execute(CodeActivityContext context) => (Location<TResult>)_invoker.InvokeExpression(context);
private Expression<Func<ActivityContext, TResult>> CompileLocationExpression(CodeActivityPublicEnvironmentAccessor publicAccessor, out string validationError)
{
Expression<Func<ActivityContext, TResult>> expressionTreeToReturn = null;
validationError = null;
try
{
expressionTreeToReturn = CSharpHelper.Compile<TResult>(ExpressionText, publicAccessor, true);
// inspect the expressionTree to see if it is a valid location expression(L-value)
string extraErrorMessage = null;
if (!publicAccessor.ActivityMetadata.HasViolations && (expressionTreeToReturn == null ||
!ExpressionUtilities.IsLocation(expressionTreeToReturn, typeof(TResult), out extraErrorMessage)))
{
var errorMessage = SR.InvalidLValueExpression;
if (extraErrorMessage != null)
{
errorMessage += ":" + extraErrorMessage;
}
expressionTreeToReturn = null;
validationError = SR.CompilerErrorSpecificExpression(ExpressionText, errorMessage);
}
}
catch (SourceExpressionException e)
{
validationError = e.Message;
}
return expressionTreeToReturn;
}
}
66 changes: 63 additions & 3 deletions src/UiPath.Workflow/Microsoft/CSharp/Activities/CSharpValue.cs
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
// This file is part of Core WF which is licensed under the MIT license.
// See LICENSE file in the project root for full license information.

using Microsoft.VisualBasic.Activities;
using System;
using System.Activities;
using System.Activities.ExpressionParser;
using System.Activities.Expressions;
using System.Activities.Internals;
using System.Activities.Runtime;
using System.Activities.Validation;
using System.ComponentModel;
using System.Diagnostics;
using System.Linq.Expressions;
using System.Windows.Markup;
using ActivityContext = System.Activities.ActivityContext;

namespace Microsoft.CSharp.Activities;

Expand All @@ -18,6 +22,8 @@ namespace Microsoft.CSharp.Activities;
public class CSharpValue<TResult> : TextExpressionBase<TResult>
{
private CompiledExpressionInvoker _invoker;
private Func<ActivityContext, TResult> _compiledExpression;
private Expression<Func<ActivityContext, TResult>> _expressionTree;

public CSharpValue() => UseOldFastPath = true;

Expand All @@ -28,13 +34,67 @@ public class CSharpValue<TResult> : TextExpressionBase<TResult>
[DesignerSerializationVisibility(DesignerSerializationVisibility.Hidden)]
public override string Language => CSharpHelper.Language;

public override Expression GetExpressionTree() => IsMetadataCached ? _invoker.GetExpressionTree() : throw FxTrace.Exception.AsError(new InvalidOperationException(SR.ActivityIsUncached));
public override Expression GetExpressionTree()
{
if (!IsMetadataCached)
{
throw FxTrace.Exception.AsError(new InvalidOperationException(SR.ActivityIsUncached));
}
if (_expressionTree == null)
{
if (_invoker != null)
{
return _invoker.GetExpressionTree();
}
// it's safe to create this CodeActivityMetadata here,
// because we know we are using it only as lookup purpose.
var metadata = new CodeActivityMetadata(this, GetParentEnvironment(), false);
var publicAccessor = CodeActivityPublicEnvironmentAccessor.CreateWithoutArgument(metadata);
try
{
_expressionTree = CSharpHelper.Compile<TResult>(ExpressionText, publicAccessor, false);
}
catch (SourceExpressionException e)
{
throw FxTrace.Exception.AsError(
new InvalidOperationException(SR.ExpressionTamperedSinceLastCompiled(e.Message)));
}
finally
{
metadata.Dispose();
}
}
Fx.Assert(_expressionTree.NodeType == ExpressionType.Lambda, "Lambda expression required");
return ExpressionUtilities.RewriteNonCompiledExpressionTree(_expressionTree);
}

protected override void CacheMetadata(CodeActivityMetadata metadata)
{
_expressionTree = null;
_invoker = new CompiledExpressionInvoker(this, false, metadata);
QueueForValidation<TResult>(metadata, false);

if (QueueForValidation<TResult>(metadata, false))
{
return;
}
try
{
var publicAccessor = CodeActivityPublicEnvironmentAccessor.Create(metadata);
_expressionTree = CSharpHelper.Compile<TResult>(ExpressionText, publicAccessor, false);
}
catch (SourceExpressionException e)
{
metadata.AddValidationError(e.Message);
}
}

protected override TResult Execute(CodeActivityContext context) => (TResult) _invoker.InvokeExpression(context);
protected override TResult Execute(CodeActivityContext context)
{
if (_expressionTree == null)
{
return (TResult)_invoker.InvokeExpression(context);
}
_compiledExpression ??= _expressionTree.Compile();
return _compiledExpression(context);
}
}