Skip to content

Commit

Permalink
Fix validation for c# not considering variable name is case sensitive
Browse files Browse the repository at this point in the history
  • Loading branch information
aoltean16 committed Apr 12, 2024
1 parent afc8210 commit fdf7a60
Show file tree
Hide file tree
Showing 13 changed files with 79 additions and 31 deletions.
31 changes: 31 additions & 0 deletions src/Test/TestCases.Workflows/ExpressionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -535,4 +535,35 @@ public void CSharpReferenceTypeIsCheckedForGenerics()
var result = ActivityValidationServices.Validate(sequence, _useValidator);
result.Errors.ShouldBeEmpty();
}

[Fact]
public void CS_IsCaseSensitive_WhenSearchingForVariable_Throws()
{
var seq = new Sequence();
seq.Variables.Add(new Variable<string>("ABC"));
seq.Activities.Add(new WriteLine { Text = new InArgument<string>(new CSharpValue<string>("abc")) });
var valid = ActivityValidationServices.Validate(seq, _useValidator);
valid.Errors.Count.ShouldBe(1);
valid.Errors.First().Message.ShouldBe("CS0103: The name 'abc' does not exist in the current context");
}

[Fact]
public void CS_IsCaseSensitive_WhenSearchingForVariable_Succeeds()
{
var seq = new Sequence();
seq.Variables.Add(new Variable<string>("ABC"));
seq.Activities.Add(new WriteLine { Text = new InArgument<string>(new CSharpValue<string>("ABC")) });
var valid = ActivityValidationServices.Validate(seq, _useValidator);
valid.Errors.ShouldBeEmpty();
}

[Fact]
public void VB_IsCaseInsensitive()
{
var seq = new Sequence();
seq.Variables.Add(new Variable<string>("ABC"));
seq.Activities.Add(new WriteLine { Text = new InArgument<string>(new VisualBasicValue<string>("abc")) });
var valid = ActivityValidationServices.Validate(seq, _useValidator);
valid.Errors.ShouldBeEmpty();
}
}
16 changes: 8 additions & 8 deletions src/Test/TestCases.Workflows/JitCompilerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public void VisualBasicJitCompiler_FieldAccess()
[Fact]
public void VisualBasicJitCompiler_PropertyAccess_SameNameAsVariable()
{
static Type VariableTypeGetter(string name)
static Type VariableTypeGetter(string name, StringComparison stringComparison)
{
return name switch
{
Expand All @@ -69,7 +69,7 @@ static Type VariableTypeGetter(string name)
[InlineData(17)]
public void VisualBasicJitCompiler_ExpressionWithMultipleVariablesVariables(int noOfVar)
{
static Type VariableTypeGetter(string name)
static Type VariableTypeGetter(string name, StringComparison stringComparison)
{
return name switch
{
Expand All @@ -85,7 +85,7 @@ static Type VariableTypeGetter(string name)
[Fact]
public void CSharpJitCompiler_PropertyAccess()
{
static Type VariableTypeGetter(string name)
static Type VariableTypeGetter(string name, StringComparison stringComparison)
{
return name switch
{
Expand All @@ -108,7 +108,7 @@ static Type VariableTypeGetter(string name)
[InlineData(17)]
public void CSharpJitCompiler_ExpressionWithMultipleVariablesVariables(int noOfVar)
{
static Type VariableTypeGetter(string name)
static Type VariableTypeGetter(string name, StringComparison stringComparison)
{
return name switch
{
Expand All @@ -125,7 +125,7 @@ static Type VariableTypeGetter(string name)
public void VbExpression_UndeclaredObject()
{
var expressionToCompile = "new UndeclaredClass()";
var sut = () => _vbJitCompiler.CompileExpression(new ExpressionToCompile(expressionToCompile, _namespaces, (s)=>null, typeof(object)));
var sut = () => _vbJitCompiler.CompileExpression(new ExpressionToCompile(expressionToCompile, _namespaces, (s, c)=>null, typeof(object)));

Assert.ThrowsAny<SourceExpressionException>(sut);
}
Expand All @@ -134,19 +134,19 @@ public void VbExpression_UndeclaredObject()
public void VbExpression_WithObjectInitializer()
{
var expressionToCompile = "new TestIndexerClass() With {.Field=\"1\"}";
var result = _vbJitCompiler.CompileExpression(new ExpressionToCompile(expressionToCompile, _namespaces, (s)=>null, typeof(TestIndexerClass)));
var result = _vbJitCompiler.CompileExpression(new ExpressionToCompile(expressionToCompile, _namespaces, (s, c)=>null, typeof(TestIndexerClass)));
result.ReturnType.ShouldBe(typeof(TestIndexerClass));
}

[Fact]
public void VbExpression_UndeclaredObjectWithObjectInitializer()
{
var expressionToCompile = "new UndeclaredClass() With {.Field=\"1\"}";
var sut = () => _vbJitCompiler.CompileExpression(new ExpressionToCompile(expressionToCompile, _namespaces, (s)=>null, typeof(object)));
var sut = () => _vbJitCompiler.CompileExpression(new ExpressionToCompile(expressionToCompile, _namespaces, (s, c)=>null, typeof(object)));

Assert.ThrowsAny<SourceExpressionException>(sut);
}
private static Type VariableTypeGetter(string name)
private static Type VariableTypeGetter(string name, StringComparison stringComparison)
=> name switch
{
"testIndexerClass" => typeof(TestIndexerClass),
Expand Down
11 changes: 8 additions & 3 deletions src/Test/TestCases.Workflows/XamlTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -317,16 +317,16 @@ public VisualBasicInferTypeData()
public void Should_compile_CSharp()
{
var compiler = new CSharpJitCompiler(new[] { typeof(Expression).Assembly, typeof(Enumerable).Assembly }.ToHashSet());
var result = compiler.CompileExpression(new ExpressionToCompile("source.Select(s=>s).Sum()", new[] { "System", "System.Linq", "System.Linq.Expressions", "System.Collections.Generic" },
name => name == "source" ? typeof(List<int>) : null, typeof(int)));
var result = compiler.CompileExpression(new ExpressionToCompile("source.Select(s=>s).Sum()", ["System", "System.Linq", "System.Linq.Expressions", "System.Collections.Generic"],
(name, comparer )=> name == "source" ? typeof(List<int>) : null, typeof(int)));
((Func<List<int>, int>)result.Compile())(new List<int> { 1, 2, 3 }).ShouldBe(6);
}

[Fact]
public void Should_Fail_VBConversion()
{
var compiler = new VbJitCompiler(new[] { typeof(int).Assembly, typeof(Expression).Assembly, typeof(Conversions).Assembly }.ToHashSet());
new Action(() => compiler.CompileExpression(new ExpressionToCompile("1", new[] { "System", "System.Linq", "System.Linq.Expressions" }, _ => typeof(int), typeof(string))))
new Action(() => compiler.CompileExpression(new ExpressionToCompile("1", new[] { "System", "System.Linq", "System.Linq.Expressions" }, (_, __) => typeof(int), typeof(string))))
.ShouldThrow<SourceExpressionException>().Message.ShouldContain("BC30512: Option Strict On disallows implicit conversions");
}

Expand Down Expand Up @@ -391,6 +391,7 @@ public class AheadOfTimeXamlTests : XamlTestsBase
</Activity>";
[Fact]
public void CompileExpressionsDefault() => InvokeWorkflow(CSharpExpressions);

[Fact]
public void CompileExpressionsWithCompiler() =>
new Action(() => ActivityXamlServices.Load(new StringReader(CSharpExpressions),
Expand All @@ -414,20 +415,23 @@ public void CSharpCompileError()
new Action(() => InvokeWorkflow(xaml)).ShouldThrow<InvalidOperationException>().Data.Values.Cast<string>()
.ShouldAllBe(error => error.Contains("error CS0103: The name 'constant' does not exist in the current context"));
}

[Fact]
public void SetCompiledExpressionRootForImplementation()
{
var writeLine = new WriteLine { Text = new InArgument<string>(new VisualBasicValue<string>("[s]")) };
CompiledExpressionInvoker.SetCompiledExpressionRootForImplementation(writeLine, new Expressions());
WorkflowInvoker.Invoke(writeLine);
}

[Fact]
public void ValidateSkipCompilation()
{
var writeLine = new WriteLine { Text = new InArgument<string>(new VisualBasicValue<string>("[s]")) };
var results = ActivityValidationServices.Validate(writeLine, new() { SkipExpressionCompilation = true });
results.Errors.ShouldBeEmpty();
}

[Fact]
public void DuplicateVariable()
{
Expand All @@ -439,6 +443,7 @@ public void DuplicateVariable()
var withMyVar = (WithMyVar)WorkflowInspectionServices.Resolve(root, "1.1");
((ITextExpression)((Sequence)withMyVar.Body.Handler).Activities[0]).GetExpressionTree();
}

[Fact]
public void CSharpInputOutput()
{
Expand Down
26 changes: 14 additions & 12 deletions src/UiPath.Workflow/Activities/JitCompilerHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ internal abstract class JitCompilerHelper
private static readonly FindMatch s_delegateFindAllLocationReferenceMatch = FindAllLocationReferenceMatch;

protected LocationReferenceEnvironment Environment;
protected abstract StringComparison StringComparison { get; }

// this is a flag to differentiate the cached short-cut Rewrite from the normal post-compilation Rewrite
protected bool IsShortCutRewrite;
Expand Down Expand Up @@ -190,10 +191,10 @@ private static void ExtractNamespacesAndReferences(VisualBasicSettings vbSetting
}

private static bool FindLocationReferenceMatchShortcut(LocationReference reference, string targetName,
Type targetType, out bool terminateSearch)
Type targetType, StringComparison stringComparison, out bool terminateSearch)
{
terminateSearch = false;
if (string.Equals(reference.Name, targetName, StringComparison.OrdinalIgnoreCase))
if (string.Equals(reference.Name, targetName, stringComparison))
{
if (targetType != reference.Type)
{
Expand All @@ -207,11 +208,11 @@ private static bool FindLocationReferenceMatchShortcut(LocationReference referen
return false;
}

private static bool FindFirstLocationReferenceMatch(LocationReference reference, string targetName, Type targetType,
private static bool FindFirstLocationReferenceMatch(LocationReference reference, string targetName, Type targetType, StringComparison stringComparison,
out bool terminateSearch)
{
terminateSearch = false;
if (string.Equals(reference.Name, targetName, StringComparison.OrdinalIgnoreCase))
if (string.Equals(reference.Name, targetName, stringComparison))
{
terminateSearch = true;
return true;
Expand All @@ -220,11 +221,11 @@ private static bool FindFirstLocationReferenceMatch(LocationReference reference,
return false;
}

private static bool FindAllLocationReferenceMatch(LocationReference reference, string targetName, Type targetType,
private static bool FindAllLocationReferenceMatch(LocationReference reference, string targetName, Type targetType, StringComparison stringComparison,
out bool terminateSearch)
{
terminateSearch = false;
if (string.Equals(reference.Name, targetName, StringComparison.OrdinalIgnoreCase))
if (string.Equals(reference.Name, targetName, stringComparison))
{
return true;
}
Expand Down Expand Up @@ -434,6 +435,7 @@ protected Expression Rewrite(Expression expression, ReadOnlyCollection<Parameter
findMatch,
variableExpression.Name,
variableExpression.Type,
StringComparison,
out var foundMultiple);

if (finalReference != null && !foundMultiple)
Expand Down Expand Up @@ -1082,7 +1084,7 @@ private static void EnsureTypeReferencedRecurse(Type type, HashSet<Type> already
}

private static LocationReference FindLocationReferencesFromEnvironment(LocationReferenceEnvironment environment,
FindMatch findMatch, string targetName, Type targetType, out bool foundMultiple)
FindMatch findMatch, string targetName, Type targetType, StringComparison stringComparison, out bool foundMultiple)
{
var currentEnvironment = environment;
foundMultiple = false;
Expand All @@ -1091,7 +1093,7 @@ private static LocationReference FindLocationReferencesFromEnvironment(LocationR
LocationReference toReturn = null;
foreach (var reference in currentEnvironment.GetLocationReferences())
{
if (findMatch(reference, targetName, targetType, out var terminateSearch))
if (findMatch(reference, targetName, targetType, stringComparison, out var terminateSearch))
{
if (toReturn != null)
{
Expand Down Expand Up @@ -1119,7 +1121,7 @@ private static LocationReference FindLocationReferencesFromEnvironment(LocationR
return null;
}

private delegate bool FindMatch(LocationReference reference, string targetName, Type targetType,
private delegate bool FindMatch(LocationReference reference, string targetName, Type targetType, StringComparison stringComparison,
out bool terminateSearch);

// this is a place holder for LambdaExpression(raw Expression Tree) that is to be stored in the cache
Expand Down Expand Up @@ -1192,12 +1194,12 @@ public ScriptAndTypeScope(LocationReferenceEnvironment environmentProvider)

public string ErrorMessage { get; private set; }

public Type FindVariable(string name)
public Type FindVariable(string name, StringComparison stringComparison)
{
LocationReference referenceToReturn = null;
var findMatch = s_delegateFindAllLocationReferenceMatch;
referenceToReturn =
FindLocationReferencesFromEnvironment(_environmentProvider, findMatch, name, null, out var foundMultiple);
FindLocationReferencesFromEnvironment(_environmentProvider, findMatch, name, null, stringComparison, out var foundMultiple);
if (referenceToReturn != null)
{
if (foundMultiple)
Expand Down Expand Up @@ -1489,7 +1491,7 @@ public override LambdaExpression CompileNonGeneric(LocationReferenceEnvironment
return Expression.Lambda(finalBody, lambda.Parameters);
}

private ExpressionToCompile ExpressionToCompile(Func<string, Type> variableTypeGetter, Type lambdaReturnType)
private ExpressionToCompile ExpressionToCompile(Func<string, StringComparison, Type> variableTypeGetter, Type lambdaReturnType)
{
return new ExpressionToCompile(TextToCompile, NamespaceImports, variableTypeGetter, lambdaReturnType);
}
Expand Down
4 changes: 2 additions & 2 deletions src/UiPath.Workflow/Activities/ScriptingJitCompiler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public abstract class JustInTimeCompiler
public record CompilerInput(string Code, IReadOnlyCollection<string> ImportedNamespaces) { }

public record ExpressionToCompile(string Code, IReadOnlyCollection<string> ImportedNamespaces,
Func<string, Type> VariableTypeGetter, Type LambdaReturnType)
Func<string, StringComparison, Type> VariableTypeGetter, Type LambdaReturnType)
: CompilerInput(Code, ImportedNamespaces)
{ }

Expand Down Expand Up @@ -57,7 +57,7 @@ public override LambdaExpression CompileExpression(ExpressionToCompile expressio
var identifiers = GetIdentifiers(syntaxTree);
var resolvedIdentifiers =
identifiers
.Select(name => (Name: name, Type: expressionToCompile.VariableTypeGetter(name)))
.Select(name => (Name: name, Type: expressionToCompile.VariableTypeGetter(name, CompilerHelper.IdentifierNameComparison)))
.Where(var => var.Type != null)
.ToArray();
var names = string.Join(CompilerHelper.Comma, resolvedIdentifiers.Select(var => var.Name));
Expand Down
2 changes: 2 additions & 0 deletions src/UiPath.Workflow/Activities/Utils/CSharpCompilerHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ public sealed class CSharpCompilerHelper : CompilerHelper

public override StringComparer IdentifierNameComparer { get; } = StringComparer.Ordinal;

public override StringComparison IdentifierNameComparison { get; } = StringComparison.Ordinal;

public override string GetTypeName(Type type) =>
(string)s_typeNameFormatter.FormatTypeName(type, s_typeOptions);

Expand Down
2 changes: 2 additions & 0 deletions src/UiPath.Workflow/Activities/Utils/CompilerHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ public abstract class CompilerHelper

public abstract StringComparer IdentifierNameComparer { get; }

public abstract StringComparison IdentifierNameComparison { get; }

public abstract int IdentifierKind { get; }

public abstract (string, string) DefineDelegate(string types);
Expand Down
2 changes: 2 additions & 0 deletions src/UiPath.Workflow/Activities/Utils/VBCompilerHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ public sealed class VBCompilerHelper : CompilerHelper

public override StringComparer IdentifierNameComparer { get; } = StringComparer.OrdinalIgnoreCase;

public override StringComparison IdentifierNameComparison { get; } = StringComparison.OrdinalIgnoreCase;

public override VisualBasicParseOptions ScriptParseOptions { get; } = new VisualBasicParseOptions(kind: SourceCodeKind.Script, languageVersion: LanguageVersion.Latest);


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@ namespace Microsoft.CSharp.Activities;

internal class CSharpHelper : JitCompilerHelper<CSharpHelper>
{
public CSharpHelper(string expressionText, HashSet<AssemblyReference> assemblyReferences,
HashSet<string> namespaceImportsNames) : base(expressionText, assemblyReferences, namespaceImportsNames) { }
protected override StringComparison StringComparison => StringComparison.Ordinal;

protected override JustInTimeCompiler CreateCompiler(HashSet<Assembly> references) =>
protected override JustInTimeCompiler CreateCompiler(HashSet<Assembly> references) =>
new CSharpJitCompiler(references);

public CSharpHelper(string expressionText, HashSet<AssemblyReference> assemblyReferences,
HashSet<string> namespaceImportsNames) : base(expressionText, assemblyReferences, namespaceImportsNames) { }

internal const string Language = "C#";
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ protected override SyntaxTree GetSyntaxTreeForExpression(string expression, bool
var identifiers = syntaxTree.GetRoot().DescendantNodesAndSelf().Where(n => n.RawKind == (int)SyntaxKind.IdentifierName)
.Select(n => n.ToString()).Distinct(_compilerHelper.IdentifierNameComparer);
var resolvedIdentifiers = identifiers
.Select(name => (Name: name, Type: new ScriptAndTypeScope(environment).FindVariable(name)))
.Select(name => (Name: name, Type: new ScriptAndTypeScope(environment).FindVariable(name, _compilerHelper.IdentifierNameComparison)))
.Where(var => var.Type != null)
.ToArray();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ namespace Microsoft.VisualBasic.Activities;

internal class VisualBasicHelper : JitCompilerHelper<VisualBasicHelper>
{
protected override StringComparison StringComparison => StringComparison.OrdinalIgnoreCase;

public VisualBasicHelper(string expressionText, HashSet<AssemblyReference> assemblyReferences,
HashSet<string> namespaceImportsNames) : base(expressionText, assemblyReferences, namespaceImportsNames) { }

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ protected override SyntaxTree GetSyntaxTreeForExpression(string expression, bool
var identifiers = syntaxTree.GetRoot().DescendantNodesAndSelf().Where(n => n.RawKind == (int)SyntaxKind.IdentifierName)
.Select(n => n.ToString()).Distinct(_compilerHelper.IdentifierNameComparer);
var resolvedIdentifiers = identifiers
.Select(name => (Name: name, Type: new ScriptAndTypeScope(environment).FindVariable(name)))
.Select(name => (Name: name, Type: new ScriptAndTypeScope(environment).FindVariable(name, _compilerHelper.IdentifierNameComparison)))
.Where(var => var.Type != null)
.ToArray();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ private void AddExpressionToValidate(ExpressionToValidate expressionToValidate,
.Select(n => n.ToString()).Distinct(CompilerHelper.IdentifierNameComparer);
var resolvedIdentifiers =
identifiers
.Select(name => (Name: name, Type: new ScriptAndTypeScope(expressionToValidate.Environment).FindVariable(name)))
.Select(name => (Name: name, Type: new ScriptAndTypeScope(expressionToValidate.Environment).FindVariable(name, CompilerHelper.IdentifierNameComparison)))
.Where(var => var.Type != null)
.ToArray();

Expand Down

0 comments on commit fdf7a60

Please sign in to comment.