Skip to content

Commit

Permalink
Support dictionary literals using infix operator (#40)
Browse files Browse the repository at this point in the history
  • Loading branch information
evinyang authored and nicklockwood committed Jan 4, 2024
1 parent 81e8f21 commit 0963d56
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 1 deletion.
21 changes: 20 additions & 1 deletion Sources/AnyExpression.swift
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,24 @@ public struct AnyExpression: CustomStringConvertible {
}
}
case .function("[]", _):
return { box.store($0.map(box.load)) }
return { args in
let args = args.map(box.load)
let keyVals = args.compactMap({ $0 as? Dictionary<AnyHashable, Any>.Element })
return box.store(
args.isEmpty || args.count != keyVals.count
? args
: keyVals.reduce(into: [AnyHashable: Any](), { $0[$1.key] = $1.value })
)
}
case .infix(":"):
return { args in
switch (box.load(args[0]), box.load(args[1])) {
case let (lhs as AnyHashable, rhs):
return box.store(Dictionary<AnyHashable, Any>.Element(key: lhs, value: rhs))
case let (lhs, rhs):
throw Error.typeMismatch(symbol, [lhs, rhs])
}
}
case let .variable(name):
guard let string = unwrapString(name) else {
return { _ in throw Error.undefinedSymbol(symbol) }
Expand Down Expand Up @@ -608,6 +625,8 @@ extension AnyExpression.Error {
}
case .infix("==") where types.count == 2 && types[0] == types[1]:
return .message("Arguments for \(symbol) must conform to the Hashable protocol")
case .infix(":") where types.count == 2 && !(args[0] is AnyHashable):
return .message("First argument for \(symbol) must conform to the Hashable protocol")
case _ where types.count == 1:
return .message("Argument of type \(types[0]) is not compatible with \(symbol)")
default:
Expand Down
42 changes: 42 additions & 0 deletions Tests/AnyExpressionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,48 @@ class AnyExpressionTests: XCTestCase {
}
}

func testStringDictionaryLiteral() {
let expression = AnyExpression("['a': 1, 'b': 2.5, 'c': 3]")
XCTAssertEqual(try expression.evaluate(), ["a": 1, "b": 2.5, "c": 3])
}

func testDoubleDictionaryLiteral() {
let expression = AnyExpression("[1.5: false, 2.0: nil, 3.5: true]")
XCTAssertEqual(try expression.evaluate(), [1.5: false, 2.0: nil, 3.5: true])
}

func testIntDictionaryLiteral() {
let expression = AnyExpression("[1: 'f', 2: 'e', 3: 'd']")
XCTAssertEqual(try expression.evaluate(), [1: "f", 2: "e", 3: "d"])
}

func testDictionaryLiteralWithNonHashableKey() {
let expression = AnyExpression("[nil: false]")
XCTAssertThrowsError(try expression.evaluate() as Any) { error in
XCTAssertEqual(error as? Expression.Error, .typeMismatch(.infix(":"), [nil as Any? as Any, false]))
}
}

func testSubscriptStringDictionaryLiteralWithString() {
let expression = AnyExpression("['a': 1, 'b': 2.5, 'c': 3]['c']")
XCTAssertEqual(try expression.evaluate(), 3)
}

func testSubscriptDoubleDictionaryLiteralWithInt() {
let expression = AnyExpression("[1.5: false, 2.0: nil, 3.5: true][2]")
XCTAssertEqual(try expression.evaluate(), Optional<Bool>.none)
}

func testSubscriptIntDictionaryLiteralWithDouble() {
let expression = AnyExpression("[1: 'f', 2: 'e', 3: 'd'][1.0]")
XCTAssertEqual(try expression.evaluate(), "f")
}

func testSubscriptDictionaryLiteralWithNonexistentKey() {
let expression = AnyExpression("[1: 'f', 2: 'e', 3: 'd']['d']")
XCTAssertEqual(try expression.evaluate(), Optional<String>.none)
}

// MARK: Ranges

func testClosedIntRange() {
Expand Down

0 comments on commit 0963d56

Please sign in to comment.