Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bug: Projected winrt::consume_ methods will nullptr crash if the underlying QueryInterface call fails #1442

Merged
merged 9 commits into from
Oct 30, 2024
16 changes: 13 additions & 3 deletions cppwinrt/code_writers.h
Original file line number Diff line number Diff line change
Expand Up @@ -1131,15 +1131,21 @@ namespace cppwinrt
// we intentionally ignore errors when unregistering event handlers to be consistent with event_revoker
format = R"( template <typename D%> auto consume_%<D%>::%(%) const noexcept
{%
WINRT_IMPL_SHIM(%)->%(%);%
const auto castedResult = static_cast<% const&>(static_cast<D const&>(*this));
const auto abiType = *(abi_t<%>**)&castedResult;
check_cast_result(abiType);
dmachaj marked this conversation as resolved.
Show resolved Hide resolved
abiType->%(%);%
}
)";
}
else
{
format = R"( template <typename D%> auto consume_%<D%>::%(%) const noexcept
{%
WINRT_VERIFY_(0, WINRT_IMPL_SHIM(%)->%(%));%
const auto castedResult = static_cast<% const&>(static_cast<D const&>(*this));
dmachaj marked this conversation as resolved.
Show resolved Hide resolved
const auto abiType = *(abi_t<%>**)&castedResult;
check_cast_result(abiType);
WINRT_VERIFY_(0, abiType->%(%));%
}
)";
}
Expand All @@ -1148,7 +1154,10 @@ namespace cppwinrt
{
format = R"( template <typename D%> auto consume_%<D%>::%(%) const
{%
check_hresult(WINRT_IMPL_SHIM(%)->%(%));%
const auto castedResult = static_cast<% const&>(static_cast<D const&>(*this));
const auto abiType = *(abi_t<%>**)&castedResult;
check_cast_result(abiType);
check_hresult(abiType->%(%));%
}
dmachaj marked this conversation as resolved.
Show resolved Hide resolved
)";
}
Expand All @@ -1161,6 +1170,7 @@ namespace cppwinrt
bind<write_consume_params>(signature),
bind<write_consume_return_type>(signature, false),
type,
type,
get_abi_name(method),
bind<write_abi_args>(signature),
bind<write_consume_return_statement>(signature));
Expand Down
22 changes: 22 additions & 0 deletions strings/base_error.h
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,28 @@ WINRT_EXPORT namespace winrt
return pointer;
}

template <typename T>
WINRT_IMPL_NOINLINE void check_cast_result(T* from WINRT_IMPL_SOURCE_LOCATION_ARGS)
dmachaj marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't need to be a template. Can just take a void*. This allows the different instantiations to be folded.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's fair. If I spin another PR I'll make that change.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't this have to be WINRT_IMPL_NOINLINE inline to avoid ODR problems?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What sort of ODR problems? The only variance based on compiler options is the SOURCE_LOCATION macro and that already has global ODR guards.

The noinline is important because it is how the code bloat is avoided by having lots of exception throwing locations spewed out into containing code.

{
if (!from)
{
com_ptr<impl::IRestrictedErrorInfo> restrictedError;
if (WINRT_IMPL_GetRestrictedErrorInfo(restrictedError.put_void()) == 0)
{
WINRT_IMPL_SetRestrictedErrorInfo(restrictedError.get());

int32_t code;
impl::bstr_handle description;
impl::bstr_handle restrictedDescription;
impl::bstr_handle capabilitySid;
if (restrictedError->GetErrorDetails(description.put(), &code, restrictedDescription.put(), capabilitySid.put()) == 0)
{
check_hresult(code WINRT_IMPL_SOURCE_LOCATION_FORWARD);
}
}
}
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if code gen would be better as

WINRT_IMPL_NOINLINE [[noreturn]] void throw_failed_cast(hresult code WINRT_IMPL_SOURCE_LOCATION_ARGS)
{
    throw hresult_error(code, take_ownership_from_abi WINRT_IMPL_SOURCE_LOCATION_FORWARD);
}

inline void check_cast_result(hresult code WINRT_IMPL_SOURCE_LOCATION_ARGS)
{
    if (code < 0) throw_failed_cast(code WINRT_IMPL_SOURCE_LOCATION_FORWARD);
}

where we have a new try_as_reason that returns both a pointer and an hresult.

void* result{};
hresult code = ptr->QueryInterface(guid_of<To>(), &result);
return { wrap_as_result<To>(result), code };


[[noreturn]] inline void terminate() noexcept
{
WINRT_IMPL_RoFailFastWithErrorContext(to_hresult());
Expand Down
3 changes: 3 additions & 0 deletions strings/base_extern.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,11 @@ extern "C"
int32_t __stdcall WINRT_IMPL_SetThreadpoolTimerEx(winrt::impl::ptp_timer, void*, uint32_t, uint32_t) noexcept WINRT_IMPL_LINK(SetThreadpoolTimerEx, 16);
int32_t __stdcall WINRT_IMPL_SetThreadpoolWaitEx(winrt::impl::ptp_wait, void*, void*, void*) noexcept WINRT_IMPL_LINK(SetThreadpoolWaitEx, 16);
int32_t __stdcall WINRT_IMPL_RoOriginateLanguageException(int32_t error, void* message, void* exception) noexcept WINRT_IMPL_LINK(RoOriginateLanguageException, 12);
int32_t __stdcall WINRT_IMPL_RoCaptureErrorContext(int32_t error) noexcept WINRT_IMPL_LINK(RoCaptureErrorContext, 4);
void __stdcall WINRT_IMPL_RoFailFastWithErrorContext(int32_t) noexcept WINRT_IMPL_LINK(RoFailFastWithErrorContext, 4);
int32_t __stdcall WINRT_IMPL_RoTransformError(int32_t, int32_t, void*) noexcept WINRT_IMPL_LINK(RoTransformError, 12);
int32_t __stdcall WINRT_IMPL_GetRestrictedErrorInfo(void**) noexcept WINRT_IMPL_LINK(GetRestrictedErrorInfo, 4);
int32_t __stdcall WINRT_IMPL_SetRestrictedErrorInfo(void*) noexcept WINRT_IMPL_LINK(SetRestrictedErrorInfo, 4);

void* __stdcall WINRT_IMPL_LoadLibraryExW(wchar_t const* name, void* unused, uint32_t flags) noexcept WINRT_IMPL_LINK(LoadLibraryExW, 12);
int32_t __stdcall WINRT_IMPL_FreeLibrary(void* library) noexcept WINRT_IMPL_LINK(FreeLibrary, 4);
Expand Down
6 changes: 5 additions & 1 deletion strings/base_windows.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,11 @@ namespace winrt::impl
}

void* result{};
ptr->QueryInterface(guid_of<To>(), &result);
hresult code = ptr->QueryInterface(guid_of<To>(), &result);
if (code < 0)
{
WINRT_IMPL_RoCaptureErrorContext(code);
dmachaj marked this conversation as resolved.
Show resolved Hide resolved
}
return wrap_as_result<To>(result);
}
}
Expand Down
25 changes: 25 additions & 0 deletions test/test/missing_required_interfaces.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#include "pch.h"

// Unset lean and mean so we can implement a type from the test_component namespace
#undef WINRT_LEAN_AND_MEAN
#include <winrt/test_component.h>

namespace
{
struct LiesAboutInheritance : public winrt::implements<LiesAboutInheritance, winrt::test_component::ILiesAboutInheritance>
{
LiesAboutInheritance() = default;
void StubMethod() {}
};
}

TEST_CASE("missing_required_interfaces")
{
auto lies = winrt::make_self<LiesAboutInheritance>().as<winrt::test_component::LiesAboutInheritance>();
REQUIRE(lies);
REQUIRE_NOTHROW(lies.StubMethod());

// The IStringable::ToString method does not exist on this type. In previous versions of cppwinrt
// this line would crash with a nullptr deference. It now throws an exception.
REQUIRE_THROWS_AS(lies.ToString(), winrt::hresult_error);
}
dmachaj marked this conversation as resolved.
Show resolved Hide resolved
1 change: 1 addition & 0 deletions test/test/test.vcxproj
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@
<PrecompiledHeader>NotUsing</PrecompiledHeader>
</ClCompile>
<ClCompile Include="memory_buffer.cpp" />
<ClCompile Include="missing_required_interfaces.cpp" />
<ClCompile Include="module_lock_dll.cpp">
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">NotUsing</PrecompiledHeader>
<PrecompiledHeader Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">NotUsing</PrecompiledHeader>
Expand Down
7 changes: 7 additions & 0 deletions test/test_component/test_component.idl
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,13 @@ namespace test_component
static void StaticMethodWithAsyncReturn();
}

// This class declares that it implements another interface but under the covers it actually does
// not. This allows us to test the behavior when QI's that should not fail, do fail.
runtimeclass LiesAboutInheritance : Windows.Foundation.IStringable
{
void StubMethod();
}

namespace Structs
{
struct All
Expand Down
Loading