From 0963d562e684b8ce1e8b5ffbbf93c1c27bba9471 Mon Sep 17 00:00:00 2001 From: evinyang Date: Sat, 13 Aug 2022 07:11:43 -0700 Subject: [PATCH] Support dictionary literals using infix operator (#40) --- Sources/AnyExpression.swift | 21 ++++++++++++++++- Tests/AnyExpressionTests.swift | 42 ++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 1 deletion(-) diff --git a/Sources/AnyExpression.swift b/Sources/AnyExpression.swift index 7ef6a5b..4064029 100644 --- a/Sources/AnyExpression.swift +++ b/Sources/AnyExpression.swift @@ -401,7 +401,24 @@ public struct AnyExpression: CustomStringConvertible { } } case .function("[]", _): - return { box.store($0.map(box.load)) } + return { args in + let args = args.map(box.load) + let keyVals = args.compactMap({ $0 as? Dictionary.Element }) + return box.store( + args.isEmpty || args.count != keyVals.count + ? args + : keyVals.reduce(into: [AnyHashable: Any](), { $0[$1.key] = $1.value }) + ) + } + case .infix(":"): + return { args in + switch (box.load(args[0]), box.load(args[1])) { + case let (lhs as AnyHashable, rhs): + return box.store(Dictionary.Element(key: lhs, value: rhs)) + case let (lhs, rhs): + throw Error.typeMismatch(symbol, [lhs, rhs]) + } + } case let .variable(name): guard let string = unwrapString(name) else { return { _ in throw Error.undefinedSymbol(symbol) } @@ -608,6 +625,8 @@ extension AnyExpression.Error { } case .infix("==") where types.count == 2 && types[0] == types[1]: return .message("Arguments for \(symbol) must conform to the Hashable protocol") + case .infix(":") where types.count == 2 && !(args[0] is AnyHashable): + return .message("First argument for \(symbol) must conform to the Hashable protocol") case _ where types.count == 1: return .message("Argument of type \(types[0]) is not compatible with \(symbol)") default: diff --git a/Tests/AnyExpressionTests.swift b/Tests/AnyExpressionTests.swift index d8a758b..5fc1202 100644 --- a/Tests/AnyExpressionTests.swift +++ b/Tests/AnyExpressionTests.swift @@ -398,6 +398,48 @@ class AnyExpressionTests: XCTestCase { } } + func testStringDictionaryLiteral() { + let expression = AnyExpression("['a': 1, 'b': 2.5, 'c': 3]") + XCTAssertEqual(try expression.evaluate(), ["a": 1, "b": 2.5, "c": 3]) + } + + func testDoubleDictionaryLiteral() { + let expression = AnyExpression("[1.5: false, 2.0: nil, 3.5: true]") + XCTAssertEqual(try expression.evaluate(), [1.5: false, 2.0: nil, 3.5: true]) + } + + func testIntDictionaryLiteral() { + let expression = AnyExpression("[1: 'f', 2: 'e', 3: 'd']") + XCTAssertEqual(try expression.evaluate(), [1: "f", 2: "e", 3: "d"]) + } + + func testDictionaryLiteralWithNonHashableKey() { + let expression = AnyExpression("[nil: false]") + XCTAssertThrowsError(try expression.evaluate() as Any) { error in + XCTAssertEqual(error as? Expression.Error, .typeMismatch(.infix(":"), [nil as Any? as Any, false])) + } + } + + func testSubscriptStringDictionaryLiteralWithString() { + let expression = AnyExpression("['a': 1, 'b': 2.5, 'c': 3]['c']") + XCTAssertEqual(try expression.evaluate(), 3) + } + + func testSubscriptDoubleDictionaryLiteralWithInt() { + let expression = AnyExpression("[1.5: false, 2.0: nil, 3.5: true][2]") + XCTAssertEqual(try expression.evaluate(), Optional.none) + } + + func testSubscriptIntDictionaryLiteralWithDouble() { + let expression = AnyExpression("[1: 'f', 2: 'e', 3: 'd'][1.0]") + XCTAssertEqual(try expression.evaluate(), "f") + } + + func testSubscriptDictionaryLiteralWithNonexistentKey() { + let expression = AnyExpression("[1: 'f', 2: 'e', 3: 'd']['d']") + XCTAssertEqual(try expression.evaluate(), Optional.none) + } + // MARK: Ranges func testClosedIntRange() {