diff --git a/src/Test/TestCases.Workflows/ExpressionTests.cs b/src/Test/TestCases.Workflows/ExpressionTests.cs index 21dc74f5..bf309dd2 100644 --- a/src/Test/TestCases.Workflows/ExpressionTests.cs +++ b/src/Test/TestCases.Workflows/ExpressionTests.cs @@ -535,4 +535,35 @@ public void CSharpReferenceTypeIsCheckedForGenerics() var result = ActivityValidationServices.Validate(sequence, _useValidator); result.Errors.ShouldBeEmpty(); } + + [Fact] + public void CS_IsCaseSensitive_WhenSearchingForVariable_Throws() + { + var seq = new Sequence(); + seq.Variables.Add(new Variable("ABC")); + seq.Activities.Add(new WriteLine { Text = new InArgument(new CSharpValue("abc")) }); + var valid = ActivityValidationServices.Validate(seq, _useValidator); + valid.Errors.Count.ShouldBe(1); + valid.Errors.First().Message.ShouldBe("CS0103: The name 'abc' does not exist in the current context"); + } + + [Fact] + public void CS_IsCaseSensitive_WhenSearchingForVariable_Succeeds() + { + var seq = new Sequence(); + seq.Variables.Add(new Variable("ABC")); + seq.Activities.Add(new WriteLine { Text = new InArgument(new CSharpValue("ABC")) }); + var valid = ActivityValidationServices.Validate(seq, _useValidator); + valid.Errors.ShouldBeEmpty(); + } + + [Fact] + public void VB_IsCaseInsensitive() + { + var seq = new Sequence(); + seq.Variables.Add(new Variable("ABC")); + seq.Activities.Add(new WriteLine { Text = new InArgument(new VisualBasicValue("abc")) }); + var valid = ActivityValidationServices.Validate(seq, _useValidator); + valid.Errors.ShouldBeEmpty(); + } } \ No newline at end of file diff --git a/src/Test/TestCases.Workflows/JitCompilerTests.cs b/src/Test/TestCases.Workflows/JitCompilerTests.cs index 52f5e54c..f1321334 100644 --- a/src/Test/TestCases.Workflows/JitCompilerTests.cs +++ b/src/Test/TestCases.Workflows/JitCompilerTests.cs @@ -50,7 +50,7 @@ public void VisualBasicJitCompiler_FieldAccess() [Fact] public void VisualBasicJitCompiler_PropertyAccess_SameNameAsVariable() { - static Type VariableTypeGetter(string name) + static Type VariableTypeGetter(string name, StringComparison stringComparison) { return name switch { @@ -69,7 +69,7 @@ static Type VariableTypeGetter(string name) [InlineData(17)] public void VisualBasicJitCompiler_ExpressionWithMultipleVariablesVariables(int noOfVar) { - static Type VariableTypeGetter(string name) + static Type VariableTypeGetter(string name, StringComparison stringComparison) { return name switch { @@ -85,7 +85,7 @@ static Type VariableTypeGetter(string name) [Fact] public void CSharpJitCompiler_PropertyAccess() { - static Type VariableTypeGetter(string name) + static Type VariableTypeGetter(string name, StringComparison stringComparison) { return name switch { @@ -108,7 +108,7 @@ static Type VariableTypeGetter(string name) [InlineData(17)] public void CSharpJitCompiler_ExpressionWithMultipleVariablesVariables(int noOfVar) { - static Type VariableTypeGetter(string name) + static Type VariableTypeGetter(string name, StringComparison stringComparison) { return name switch { @@ -125,7 +125,7 @@ static Type VariableTypeGetter(string name) public void VbExpression_UndeclaredObject() { var expressionToCompile = "new UndeclaredClass()"; - var sut = () => _vbJitCompiler.CompileExpression(new ExpressionToCompile(expressionToCompile, _namespaces, (s)=>null, typeof(object))); + var sut = () => _vbJitCompiler.CompileExpression(new ExpressionToCompile(expressionToCompile, _namespaces, (s, c)=>null, typeof(object))); Assert.ThrowsAny(sut); } @@ -134,7 +134,7 @@ public void VbExpression_UndeclaredObject() public void VbExpression_WithObjectInitializer() { var expressionToCompile = "new TestIndexerClass() With {.Field=\"1\"}"; - var result = _vbJitCompiler.CompileExpression(new ExpressionToCompile(expressionToCompile, _namespaces, (s)=>null, typeof(TestIndexerClass))); + var result = _vbJitCompiler.CompileExpression(new ExpressionToCompile(expressionToCompile, _namespaces, (s, c)=>null, typeof(TestIndexerClass))); result.ReturnType.ShouldBe(typeof(TestIndexerClass)); } @@ -142,11 +142,11 @@ public void VbExpression_WithObjectInitializer() public void VbExpression_UndeclaredObjectWithObjectInitializer() { var expressionToCompile = "new UndeclaredClass() With {.Field=\"1\"}"; - var sut = () => _vbJitCompiler.CompileExpression(new ExpressionToCompile(expressionToCompile, _namespaces, (s)=>null, typeof(object))); + var sut = () => _vbJitCompiler.CompileExpression(new ExpressionToCompile(expressionToCompile, _namespaces, (s, c)=>null, typeof(object))); Assert.ThrowsAny(sut); } - private static Type VariableTypeGetter(string name) + private static Type VariableTypeGetter(string name, StringComparison stringComparison) => name switch { "testIndexerClass" => typeof(TestIndexerClass), diff --git a/src/Test/TestCases.Workflows/XamlTests.cs b/src/Test/TestCases.Workflows/XamlTests.cs index b2191424..eb2d95c6 100644 --- a/src/Test/TestCases.Workflows/XamlTests.cs +++ b/src/Test/TestCases.Workflows/XamlTests.cs @@ -317,8 +317,8 @@ public VisualBasicInferTypeData() public void Should_compile_CSharp() { var compiler = new CSharpJitCompiler(new[] { typeof(Expression).Assembly, typeof(Enumerable).Assembly }.ToHashSet()); - var result = compiler.CompileExpression(new ExpressionToCompile("source.Select(s=>s).Sum()", new[] { "System", "System.Linq", "System.Linq.Expressions", "System.Collections.Generic" }, - name => name == "source" ? typeof(List) : null, typeof(int))); + var result = compiler.CompileExpression(new ExpressionToCompile("source.Select(s=>s).Sum()", ["System", "System.Linq", "System.Linq.Expressions", "System.Collections.Generic"], + (name, comparer )=> name == "source" ? typeof(List) : null, typeof(int))); ((Func, int>)result.Compile())(new List { 1, 2, 3 }).ShouldBe(6); } @@ -326,7 +326,7 @@ public void Should_compile_CSharp() public void Should_Fail_VBConversion() { var compiler = new VbJitCompiler(new[] { typeof(int).Assembly, typeof(Expression).Assembly, typeof(Conversions).Assembly }.ToHashSet()); - new Action(() => compiler.CompileExpression(new ExpressionToCompile("1", new[] { "System", "System.Linq", "System.Linq.Expressions" }, _ => typeof(int), typeof(string)))) + new Action(() => compiler.CompileExpression(new ExpressionToCompile("1", new[] { "System", "System.Linq", "System.Linq.Expressions" }, (_, __) => typeof(int), typeof(string)))) .ShouldThrow().Message.ShouldContain("BC30512: Option Strict On disallows implicit conversions"); } @@ -391,6 +391,7 @@ public class AheadOfTimeXamlTests : XamlTestsBase "; [Fact] public void CompileExpressionsDefault() => InvokeWorkflow(CSharpExpressions); + [Fact] public void CompileExpressionsWithCompiler() => new Action(() => ActivityXamlServices.Load(new StringReader(CSharpExpressions), @@ -414,6 +415,7 @@ public void CSharpCompileError() new Action(() => InvokeWorkflow(xaml)).ShouldThrow().Data.Values.Cast() .ShouldAllBe(error => error.Contains("error CS0103: The name 'constant' does not exist in the current context")); } + [Fact] public void SetCompiledExpressionRootForImplementation() { @@ -421,6 +423,7 @@ public void SetCompiledExpressionRootForImplementation() CompiledExpressionInvoker.SetCompiledExpressionRootForImplementation(writeLine, new Expressions()); WorkflowInvoker.Invoke(writeLine); } + [Fact] public void ValidateSkipCompilation() { @@ -428,6 +431,7 @@ public void ValidateSkipCompilation() var results = ActivityValidationServices.Validate(writeLine, new() { SkipExpressionCompilation = true }); results.Errors.ShouldBeEmpty(); } + [Fact] public void DuplicateVariable() { @@ -439,6 +443,7 @@ public void DuplicateVariable() var withMyVar = (WithMyVar)WorkflowInspectionServices.Resolve(root, "1.1"); ((ITextExpression)((Sequence)withMyVar.Body.Handler).Activities[0]).GetExpressionTree(); } + [Fact] public void CSharpInputOutput() { diff --git a/src/UiPath.Workflow/Activities/JitCompilerHelper.cs b/src/UiPath.Workflow/Activities/JitCompilerHelper.cs index b4ffdb09..bde15615 100644 --- a/src/UiPath.Workflow/Activities/JitCompilerHelper.cs +++ b/src/UiPath.Workflow/Activities/JitCompilerHelper.cs @@ -51,6 +51,7 @@ internal abstract class JitCompilerHelper private static readonly FindMatch s_delegateFindAllLocationReferenceMatch = FindAllLocationReferenceMatch; protected LocationReferenceEnvironment Environment; + protected abstract StringComparison StringComparison { get; } // this is a flag to differentiate the cached short-cut Rewrite from the normal post-compilation Rewrite protected bool IsShortCutRewrite; @@ -190,10 +191,10 @@ private static void ExtractNamespacesAndReferences(VisualBasicSettings vbSetting } private static bool FindLocationReferenceMatchShortcut(LocationReference reference, string targetName, - Type targetType, out bool terminateSearch) + Type targetType, StringComparison stringComparison, out bool terminateSearch) { terminateSearch = false; - if (string.Equals(reference.Name, targetName, StringComparison.OrdinalIgnoreCase)) + if (string.Equals(reference.Name, targetName, stringComparison)) { if (targetType != reference.Type) { @@ -207,11 +208,11 @@ private static bool FindLocationReferenceMatchShortcut(LocationReference referen return false; } - private static bool FindFirstLocationReferenceMatch(LocationReference reference, string targetName, Type targetType, + private static bool FindFirstLocationReferenceMatch(LocationReference reference, string targetName, Type targetType, StringComparison stringComparison, out bool terminateSearch) { terminateSearch = false; - if (string.Equals(reference.Name, targetName, StringComparison.OrdinalIgnoreCase)) + if (string.Equals(reference.Name, targetName, stringComparison)) { terminateSearch = true; return true; @@ -220,11 +221,11 @@ private static bool FindFirstLocationReferenceMatch(LocationReference reference, return false; } - private static bool FindAllLocationReferenceMatch(LocationReference reference, string targetName, Type targetType, + private static bool FindAllLocationReferenceMatch(LocationReference reference, string targetName, Type targetType, StringComparison stringComparison, out bool terminateSearch) { terminateSearch = false; - if (string.Equals(reference.Name, targetName, StringComparison.OrdinalIgnoreCase)) + if (string.Equals(reference.Name, targetName, stringComparison)) { return true; } @@ -434,6 +435,7 @@ protected Expression Rewrite(Expression expression, ReadOnlyCollection already } private static LocationReference FindLocationReferencesFromEnvironment(LocationReferenceEnvironment environment, - FindMatch findMatch, string targetName, Type targetType, out bool foundMultiple) + FindMatch findMatch, string targetName, Type targetType, StringComparison stringComparison, out bool foundMultiple) { var currentEnvironment = environment; foundMultiple = false; @@ -1091,7 +1093,7 @@ private static LocationReference FindLocationReferencesFromEnvironment(LocationR LocationReference toReturn = null; foreach (var reference in currentEnvironment.GetLocationReferences()) { - if (findMatch(reference, targetName, targetType, out var terminateSearch)) + if (findMatch(reference, targetName, targetType, stringComparison, out var terminateSearch)) { if (toReturn != null) { @@ -1119,7 +1121,7 @@ private static LocationReference FindLocationReferencesFromEnvironment(LocationR return null; } - private delegate bool FindMatch(LocationReference reference, string targetName, Type targetType, + private delegate bool FindMatch(LocationReference reference, string targetName, Type targetType, StringComparison stringComparison, out bool terminateSearch); // this is a place holder for LambdaExpression(raw Expression Tree) that is to be stored in the cache @@ -1192,12 +1194,12 @@ public ScriptAndTypeScope(LocationReferenceEnvironment environmentProvider) public string ErrorMessage { get; private set; } - public Type FindVariable(string name) + public Type FindVariable(string name, StringComparison stringComparison) { LocationReference referenceToReturn = null; var findMatch = s_delegateFindAllLocationReferenceMatch; referenceToReturn = - FindLocationReferencesFromEnvironment(_environmentProvider, findMatch, name, null, out var foundMultiple); + FindLocationReferencesFromEnvironment(_environmentProvider, findMatch, name, null, stringComparison, out var foundMultiple); if (referenceToReturn != null) { if (foundMultiple) @@ -1489,7 +1491,7 @@ public override LambdaExpression CompileNonGeneric(LocationReferenceEnvironment return Expression.Lambda(finalBody, lambda.Parameters); } - private ExpressionToCompile ExpressionToCompile(Func variableTypeGetter, Type lambdaReturnType) + private ExpressionToCompile ExpressionToCompile(Func variableTypeGetter, Type lambdaReturnType) { return new ExpressionToCompile(TextToCompile, NamespaceImports, variableTypeGetter, lambdaReturnType); } diff --git a/src/UiPath.Workflow/Activities/ScriptingJitCompiler.cs b/src/UiPath.Workflow/Activities/ScriptingJitCompiler.cs index 4ed9c479..35b446e0 100644 --- a/src/UiPath.Workflow/Activities/ScriptingJitCompiler.cs +++ b/src/UiPath.Workflow/Activities/ScriptingJitCompiler.cs @@ -29,7 +29,7 @@ public abstract class JustInTimeCompiler public record CompilerInput(string Code, IReadOnlyCollection ImportedNamespaces) { } public record ExpressionToCompile(string Code, IReadOnlyCollection ImportedNamespaces, - Func VariableTypeGetter, Type LambdaReturnType) + Func VariableTypeGetter, Type LambdaReturnType) : CompilerInput(Code, ImportedNamespaces) { } @@ -57,7 +57,7 @@ public override LambdaExpression CompileExpression(ExpressionToCompile expressio var identifiers = GetIdentifiers(syntaxTree); var resolvedIdentifiers = identifiers - .Select(name => (Name: name, Type: expressionToCompile.VariableTypeGetter(name))) + .Select(name => (Name: name, Type: expressionToCompile.VariableTypeGetter(name, CompilerHelper.IdentifierNameComparison))) .Where(var => var.Type != null) .ToArray(); var names = string.Join(CompilerHelper.Comma, resolvedIdentifiers.Select(var => var.Name)); diff --git a/src/UiPath.Workflow/Activities/Utils/CSharpCompilerHelper.cs b/src/UiPath.Workflow/Activities/Utils/CSharpCompilerHelper.cs index 1f247b96..03d8d644 100644 --- a/src/UiPath.Workflow/Activities/Utils/CSharpCompilerHelper.cs +++ b/src/UiPath.Workflow/Activities/Utils/CSharpCompilerHelper.cs @@ -23,6 +23,8 @@ public sealed class CSharpCompilerHelper : CompilerHelper public override StringComparer IdentifierNameComparer { get; } = StringComparer.Ordinal; + public override StringComparison IdentifierNameComparison { get; } = StringComparison.Ordinal; + public override string GetTypeName(Type type) => (string)s_typeNameFormatter.FormatTypeName(type, s_typeOptions); diff --git a/src/UiPath.Workflow/Activities/Utils/CompilerHelper.cs b/src/UiPath.Workflow/Activities/Utils/CompilerHelper.cs index d3ff3220..ccc53a6a 100644 --- a/src/UiPath.Workflow/Activities/Utils/CompilerHelper.cs +++ b/src/UiPath.Workflow/Activities/Utils/CompilerHelper.cs @@ -17,6 +17,8 @@ public abstract class CompilerHelper public abstract StringComparer IdentifierNameComparer { get; } + public abstract StringComparison IdentifierNameComparison { get; } + public abstract int IdentifierKind { get; } public abstract (string, string) DefineDelegate(string types); diff --git a/src/UiPath.Workflow/Activities/Utils/VBCompilerHelper.cs b/src/UiPath.Workflow/Activities/Utils/VBCompilerHelper.cs index 065bfcb5..77364b0c 100644 --- a/src/UiPath.Workflow/Activities/Utils/VBCompilerHelper.cs +++ b/src/UiPath.Workflow/Activities/Utils/VBCompilerHelper.cs @@ -17,6 +17,8 @@ public sealed class VBCompilerHelper : CompilerHelper public override StringComparer IdentifierNameComparer { get; } = StringComparer.OrdinalIgnoreCase; + public override StringComparison IdentifierNameComparison { get; } = StringComparison.OrdinalIgnoreCase; + public override VisualBasicParseOptions ScriptParseOptions { get; } = new VisualBasicParseOptions(kind: SourceCodeKind.Script, languageVersion: LanguageVersion.Latest); diff --git a/src/UiPath.Workflow/Microsoft/CSharp/Activities/CSharpDesignerHelper.cs b/src/UiPath.Workflow/Microsoft/CSharp/Activities/CSharpDesignerHelper.cs index d55d4385..dd66abde 100644 --- a/src/UiPath.Workflow/Microsoft/CSharp/Activities/CSharpDesignerHelper.cs +++ b/src/UiPath.Workflow/Microsoft/CSharp/Activities/CSharpDesignerHelper.cs @@ -14,12 +14,14 @@ namespace Microsoft.CSharp.Activities; internal class CSharpHelper : JitCompilerHelper { - public CSharpHelper(string expressionText, HashSet assemblyReferences, - HashSet namespaceImportsNames) : base(expressionText, assemblyReferences, namespaceImportsNames) { } + protected override StringComparison StringComparison => StringComparison.Ordinal; - protected override JustInTimeCompiler CreateCompiler(HashSet references) => + protected override JustInTimeCompiler CreateCompiler(HashSet references) => new CSharpJitCompiler(references); + public CSharpHelper(string expressionText, HashSet assemblyReferences, + HashSet namespaceImportsNames) : base(expressionText, assemblyReferences, namespaceImportsNames) { } + internal const string Language = "C#"; } diff --git a/src/UiPath.Workflow/Microsoft/CSharp/CSharpExpressionCompiler.cs b/src/UiPath.Workflow/Microsoft/CSharp/CSharpExpressionCompiler.cs index 0da3541a..e85ca6a8 100644 --- a/src/UiPath.Workflow/Microsoft/CSharp/CSharpExpressionCompiler.cs +++ b/src/UiPath.Workflow/Microsoft/CSharp/CSharpExpressionCompiler.cs @@ -43,7 +43,7 @@ protected override SyntaxTree GetSyntaxTreeForExpression(string expression, bool var identifiers = syntaxTree.GetRoot().DescendantNodesAndSelf().Where(n => n.RawKind == (int)SyntaxKind.IdentifierName) .Select(n => n.ToString()).Distinct(_compilerHelper.IdentifierNameComparer); var resolvedIdentifiers = identifiers - .Select(name => (Name: name, Type: new ScriptAndTypeScope(environment).FindVariable(name))) + .Select(name => (Name: name, Type: new ScriptAndTypeScope(environment).FindVariable(name, _compilerHelper.IdentifierNameComparison))) .Where(var => var.Type != null) .ToArray(); diff --git a/src/UiPath.Workflow/Microsoft/VisualBasic/Activities/VisualBasicDesignerHelper.cs b/src/UiPath.Workflow/Microsoft/VisualBasic/Activities/VisualBasicDesignerHelper.cs index 54f90ca8..5237d3cf 100644 --- a/src/UiPath.Workflow/Microsoft/VisualBasic/Activities/VisualBasicDesignerHelper.cs +++ b/src/UiPath.Workflow/Microsoft/VisualBasic/Activities/VisualBasicDesignerHelper.cs @@ -14,6 +14,8 @@ namespace Microsoft.VisualBasic.Activities; internal class VisualBasicHelper : JitCompilerHelper { + protected override StringComparison StringComparison => StringComparison.OrdinalIgnoreCase; + public VisualBasicHelper(string expressionText, HashSet assemblyReferences, HashSet namespaceImportsNames) : base(expressionText, assemblyReferences, namespaceImportsNames) { } diff --git a/src/UiPath.Workflow/Microsoft/VisualBasic/VisualBasicExpressionCompiler.cs b/src/UiPath.Workflow/Microsoft/VisualBasic/VisualBasicExpressionCompiler.cs index d8263dfa..81e0c55f 100644 --- a/src/UiPath.Workflow/Microsoft/VisualBasic/VisualBasicExpressionCompiler.cs +++ b/src/UiPath.Workflow/Microsoft/VisualBasic/VisualBasicExpressionCompiler.cs @@ -45,7 +45,7 @@ protected override SyntaxTree GetSyntaxTreeForExpression(string expression, bool var identifiers = syntaxTree.GetRoot().DescendantNodesAndSelf().Where(n => n.RawKind == (int)SyntaxKind.IdentifierName) .Select(n => n.ToString()).Distinct(_compilerHelper.IdentifierNameComparer); var resolvedIdentifiers = identifiers - .Select(name => (Name: name, Type: new ScriptAndTypeScope(environment).FindVariable(name))) + .Select(name => (Name: name, Type: new ScriptAndTypeScope(environment).FindVariable(name, _compilerHelper.IdentifierNameComparison))) .Where(var => var.Type != null) .ToArray(); diff --git a/src/UiPath.Workflow/Validation/RoslynExpressionValidator.cs b/src/UiPath.Workflow/Validation/RoslynExpressionValidator.cs index fcc6a9ef..5844222c 100644 --- a/src/UiPath.Workflow/Validation/RoslynExpressionValidator.cs +++ b/src/UiPath.Workflow/Validation/RoslynExpressionValidator.cs @@ -241,7 +241,7 @@ private void AddExpressionToValidate(ExpressionToValidate expressionToValidate, .Select(n => n.ToString()).Distinct(CompilerHelper.IdentifierNameComparer); var resolvedIdentifiers = identifiers - .Select(name => (Name: name, Type: new ScriptAndTypeScope(expressionToValidate.Environment).FindVariable(name))) + .Select(name => (Name: name, Type: new ScriptAndTypeScope(expressionToValidate.Environment).FindVariable(name, CompilerHelper.IdentifierNameComparison))) .Where(var => var.Type != null) .ToArray();