diff --git a/runtime/interpreter/interpreter_expression.go b/runtime/interpreter/interpreter_expression.go index 330610d2da..d2894749bf 100644 --- a/runtime/interpreter/interpreter_expression.go +++ b/runtime/interpreter/interpreter_expression.go @@ -1335,6 +1335,13 @@ func (interpreter *Interpreter) VisitCastingExpression(expression *ast.CastingEx // thus this is the only place where it becomes necessary to "instantiate" the result of a map to its // concrete outputs. In other places (e.g. interface conformance checks) we want to leave maps generic, // so we don't substitute them. + + // if the target is anystruct or anyresource we want to preserve optionals + unboxedExpectedType := sema.UnwrapOptionalType(expectedType) + if !(unboxedExpectedType == sema.AnyStructType || unboxedExpectedType == sema.AnyResourceType) { + // otherwise dynamic cast now always unboxes optionals + value = interpreter.Unbox(locationRange, value) + } valueSemaType := interpreter.SubstituteMappedEntitlements(interpreter.MustSemaTypeOfValue(value)) valueStaticType := ConvertSemaToStaticType(interpreter, valueSemaType) isSubType := interpreter.IsSubTypeOfSemaType(valueStaticType, expectedType) diff --git a/runtime/tests/interpreter/dynamic_casting_test.go b/runtime/tests/interpreter/dynamic_casting_test.go index 933e4f12f8..2a3a96292f 100644 --- a/runtime/tests/interpreter/dynamic_casting_test.go +++ b/runtime/tests/interpreter/dynamic_casting_test.go @@ -3870,3 +3870,409 @@ func TestInterpretDynamicCastingReferenceCasting(t *testing.T) { } }) } +func TestInterpretDynamicCastingOptionalUnwrapping(t *testing.T) { + + t.Parallel() + + t.Run("as?", func(t *testing.T) { + t.Parallel() + + code := ` + let x: Int? = 42 + let y: Int? = x as? Int + ` + + inter := parseCheckAndInterpret(t, code) + + AssertValuesEqual( + t, + inter, + interpreter.NewUnmeteredSomeValueNonCopying(interpreter.NewUnmeteredIntValueFromInt64(42)), + inter.Globals.Get("y").GetValue(inter), + ) + }) + + t.Run("as!", func(t *testing.T) { + t.Parallel() + + code := ` + let x: Int? = 42 + let y: Int = x as! Int + ` + + inter := parseCheckAndInterpret(t, code) + + AssertValuesEqual( + t, + inter, + interpreter.NewUnmeteredIntValueFromInt64(42), + inter.Globals.Get("y").GetValue(inter), + ) + }) + + t.Run("multi optional as!", func(t *testing.T) { + t.Parallel() + + code := ` + let x: Int??? = 42 + let y: Int = x as! Int + ` + + inter := parseCheckAndInterpret(t, code) + + AssertValuesEqual( + t, + inter, + interpreter.NewUnmeteredIntValueFromInt64(42), + inter.Globals.Get("y").GetValue(inter), + ) + }) + + t.Run("multi optional as?", func(t *testing.T) { + t.Parallel() + + code := ` + let x: Int??? = 42 + let y: Int? = x as? Int + ` + + inter := parseCheckAndInterpret(t, code) + + AssertValuesEqual( + t, + inter, + interpreter.NewUnmeteredSomeValueNonCopying(interpreter.NewUnmeteredIntValueFromInt64(42)), + inter.Globals.Get("y").GetValue(inter), + ) + }) + + t.Run("nil as?", func(t *testing.T) { + t.Parallel() + + code := ` + let x: Int? = nil + let y: Int? = x as? Int + ` + + inter := parseCheckAndInterpret(t, code) + + AssertValuesEqual( + t, + inter, + interpreter.Nil, + inter.Globals.Get("y").GetValue(inter), + ) + }) + + t.Run("nil as!", func(t *testing.T) { + t.Parallel() + + code := ` + fun test() { + let x: Int? = nil + let y: Int = x as! Int + } + ` + + inter := parseCheckAndInterpret(t, code) + + _, err := inter.Invoke("test") + RequireError(t, err) + + assert.ErrorAs(t, err, &interpreter.ForceCastTypeMismatchError{}) + }) + + t.Run("string as!", func(t *testing.T) { + t.Parallel() + + code := ` + fun test(): String { + let hello: String???? = "hello" + let something: AnyStruct = hello + return something as! String + } + ` + + inter := parseCheckAndInterpret(t, code) + + result, err := inter.Invoke("test") + require.NoError(t, err) + require.Equal(t, interpreter.NewUnmeteredStringValue("hello"), result) + }) + + t.Run("return nested optional as?", func(t *testing.T) { + t.Parallel() + + code := ` + let x: Int??? = 42 + let y: Int?? = x as? Int? + ` + + inter := parseCheckAndInterpret(t, code) + + AssertValuesEqual( + t, + inter, + interpreter.NewUnmeteredSomeValueNonCopying( + interpreter.NewUnmeteredSomeValueNonCopying( + interpreter.NewUnmeteredIntValueFromInt64(42), + ), + ), + inter.Globals.Get("y").GetValue(inter), + ) + }) + + t.Run("return optional as!", func(t *testing.T) { + t.Parallel() + + code := ` + let x: Int??? = 42 + let y: Int? = x as! Int? + ` + + inter := parseCheckAndInterpret(t, code) + + AssertValuesEqual( + t, + inter, + interpreter.NewUnmeteredSomeValueNonCopying( + interpreter.NewUnmeteredIntValueFromInt64(42), + ), + inter.Globals.Get("y").GetValue(inter), + ) + }) + + t.Run("return nested optional as!", func(t *testing.T) { + t.Parallel() + + code := ` + let x: Int??? = 42 + let y: Int?? = x as! Int?? + ` + + inter := parseCheckAndInterpret(t, code) + + AssertValuesEqual( + t, + inter, + interpreter.NewUnmeteredSomeValueNonCopying( + interpreter.NewUnmeteredSomeValueNonCopying( + interpreter.NewUnmeteredIntValueFromInt64(42), + ), + ), + inter.Globals.Get("y").GetValue(inter), + ) + }) + + t.Run("return optional as?", func(t *testing.T) { + t.Parallel() + + code := ` + let x: Int??? = 42 + let y: Int??? = x as? Int?? + ` + + inter := parseCheckAndInterpret(t, code) + + AssertValuesEqual( + t, + inter, + interpreter.NewUnmeteredSomeValueNonCopying( + interpreter.NewUnmeteredSomeValueNonCopying( + interpreter.NewUnmeteredSomeValueNonCopying( + interpreter.NewUnmeteredIntValueFromInt64(42), + ), + ), + ), + inter.Globals.Get("y").GetValue(inter), + ) + }) + + t.Run("AnyResource as!", func(t *testing.T) { + t.Parallel() + + code := ` + resource R {} + + let x: @R? <- create R() + let y: @AnyResource <- x as! @AnyResource + ` + + inter := parseCheckAndInterpret(t, code) + + value := inter.Globals.Get("y").GetValue(inter) + + require.IsType(t, + &interpreter.SomeValue{}, + value, + ) + + require.IsType(t, + &interpreter.CompositeValue{}, + value.(*interpreter.SomeValue). + InnerValue(inter, interpreter.EmptyLocationRange), + ) + }) + + t.Run("resource cast as!", func(t *testing.T) { + t.Parallel() + + code := ` + resource R {} + + let x: @R?? <- create R() + let y: @R <- x as! @R + ` + + inter := parseCheckAndInterpret(t, code) + + value := inter.Globals.Get("y").GetValue(inter) + + require.IsType(t, + &interpreter.CompositeValue{}, + value, + ) + }) + + t.Run("resource cast as?", func(t *testing.T) { + t.Parallel() + + code := ` + resource R {} + + fun test(): @R? { + + let x: @R? <- create R() + + if let z <- x as? @R { + return <-z + } else { + destroy x + return nil + } + } + + ` + + inter := parseCheckAndInterpret(t, code) + + result, err := inter.Invoke("test") + require.NoError(t, err) + require.IsType(t, + &interpreter.SomeValue{}, + result, + ) + + require.IsType(t, + &interpreter.CompositeValue{}, + result.(*interpreter.SomeValue). + InnerValue(inter, interpreter.EmptyLocationRange), + ) + + }) + + t.Run("resource cast AnyResource as?", func(t *testing.T) { + t.Parallel() + + code := ` + resource R {} + + fun test(): @AnyResource? { + + let x: @R? <- create R() + + if let z <- x as? @AnyResource { + return <-z + } else { + destroy x + return nil + } + } + + ` + + inter := parseCheckAndInterpret(t, code) + + result, err := inter.Invoke("test") + require.NoError(t, err) + require.IsType(t, + &interpreter.SomeValue{}, + result, + ) + + require.IsType(t, + &interpreter.CompositeValue{}, + result.(*interpreter.SomeValue). + InnerValue(inter, interpreter.EmptyLocationRange), + ) + + }) + + t.Run("AnyStruct boxing as!", func(t *testing.T) { + t.Parallel() + + code := ` + let x: Int? = 42 + let y: AnyStruct??? = x as! AnyStruct?? + ` + + inter := parseCheckAndInterpret(t, code) + + AssertValuesEqual( + t, + inter, + interpreter.NewUnmeteredSomeValueNonCopying( + interpreter.NewUnmeteredSomeValueNonCopying( + interpreter.NewUnmeteredSomeValueNonCopying( + interpreter.NewUnmeteredIntValueFromInt64(42), + ), + ), + ), + inter.Globals.Get("y").GetValue(inter), + ) + }) + + t.Run("AnyStruct unboxing as!", func(t *testing.T) { + t.Parallel() + + code := ` + let x: Int??? = 42 + let y: AnyStruct = x as! AnyStruct?? + ` + + inter := parseCheckAndInterpret(t, code) + + AssertValuesEqual( + t, + inter, + interpreter.NewUnmeteredSomeValueNonCopying( + interpreter.NewUnmeteredSomeValueNonCopying( + interpreter.NewUnmeteredSomeValueNonCopying( + interpreter.NewUnmeteredIntValueFromInt64(42), + ), + ), + ), + inter.Globals.Get("y").GetValue(inter), + ) + }) + + t.Run("AnyStruct cast to Int? as!", func(t *testing.T) { + t.Parallel() + + code := ` + let x: AnyStruct = 42 + let y: Int? = x as! Int? + ` + + inter := parseCheckAndInterpret(t, code) + + AssertValuesEqual( + t, + inter, + interpreter.NewUnmeteredSomeValueNonCopying( + interpreter.NewUnmeteredIntValueFromInt64(42), + ), + inter.Globals.Get("y").GetValue(inter), + ) + }) +}