Skip to content

Commit

Permalink
Changed ReturnCopy to ReturnCapture that can also capture by move + r…
Browse files Browse the repository at this point in the history
…emoved implicit ReturnCopy.
  • Loading branch information
FranckRJ committed Apr 21, 2024
1 parent e814a86 commit f507d1c
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 41 deletions.
72 changes: 43 additions & 29 deletions include/fakeit/StubbingProgress.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,10 @@ namespace fakeit {
return DoImpl(new Repeat<R, arglist...>(method));
}

virtual void AlwaysDo(std::function<R(const typename fakeit::test_arg<arglist>::type...)> method) {
DoImpl(new RepeatForever<R, arglist...>(method));
}

protected:
virtual MethodStubbingProgress<R, arglist...>& DoImpl(Action<R, arglist...> *action) = 0;
};
Expand All @@ -61,16 +65,49 @@ namespace fakeit {
template<typename R, typename ... arglist>
struct BasicReturnImpl<R, true, arglist...> : public BasicDoImpl<R, arglist...> {
using BasicDoImpl<R, arglist...>::Do;
using BasicDoImpl<R, arglist...>::AlwaysDo;

MethodStubbingProgress<R, arglist...>& Return(const R& r) {
return Do([&r](const typename fakeit::test_arg<arglist>::type...) -> R { return r; });
}

template <typename U = R>
MethodStubbingProgress<R, arglist...>& Return(typename std::remove_cv<typename std::remove_reference<R>::type>::type&& r) {
static_assert(sizeof(U) != sizeof(U), "Return() cannot take an rvalue references for functions returning a reference because it would make it dangling, use ReturnCapture() instead.");
return Return(r); // Only written to silence warning about not returning from a non-void function, but will never be executed.
}

void AlwaysReturn(const R &r) {
return AlwaysDo([&r](const typename fakeit::test_arg<arglist>::type...) -> R { return r; });
}

template <typename U = R>
void AlwaysReturn(typename std::remove_cv<typename std::remove_reference<R>::type>::type&&) {
static_assert(sizeof(U) != sizeof(U), "AlwaysReturn() cannot take an rvalue references for functions returning a reference because it would make it dangling, use AlwaysReturnCapture() instead.");
}

template<typename T>
MethodStubbingProgress<R, arglist...>& ReturnCapture(T&& r) {
auto store = std::make_shared<typename std::remove_reference<R>::type>(std::forward<T>(r));
return Do([store](const typename fakeit::test_arg<arglist>::type...) mutable -> R {
return std::forward<R>(*store);
});
}

template<typename T>
void AlwaysReturnCapture(T&& r) {
auto store = std::make_shared<typename std::remove_reference<R>::type>(std::forward<T>(r));
return AlwaysDo([store](const typename fakeit::test_arg<arglist>::type...) mutable -> R {
return std::forward<R>(*store);
});
}
};

// If R is not a reference.
template<typename R, typename ... arglist>
struct BasicReturnImpl<R, false, arglist...> : public BasicDoImpl<R, arglist...> {
using BasicDoImpl<R, arglist...>::Do;
using BasicDoImpl<R, arglist...>::AlwaysDo;

MethodStubbingProgress<R, arglist...>& Return(const R& r) {
return Do([r](const typename fakeit::test_arg<arglist>::type...) -> R { return r; });
Expand All @@ -82,6 +119,10 @@ namespace fakeit {
return std::move(*store);
});
}

void AlwaysReturn(const R &r) {
return AlwaysDo([r](const typename fakeit::test_arg<arglist>::type...) -> R { return r; });
}
};

template<typename R, typename ... arglist>
Expand All @@ -97,13 +138,9 @@ namespace fakeit {

public:
using helper::BasicReturnImplHelper<R, arglist...>::Do;
using helper::BasicReturnImplHelper<R, arglist...>::AlwaysDo;
using helper::BasicReturnImplHelper<R, arglist...>::Return;

template<typename U = R>
typename std::enable_if<std::is_copy_constructible<U>::value, MethodStubbingProgress<R, arglist...>&>::type
ReturnCopy(const R& r) {
return Do([r](const typename fakeit::test_arg<arglist>::type...) -> R { return r; });
}
using helper::BasicReturnImplHelper<R, arglist...>::AlwaysReturn;

MethodStubbingProgress<R, arglist...> &
Return(const Quantifier<R> &q) {
Expand All @@ -119,25 +156,6 @@ namespace fakeit {
return Return(std::forward<Second>(s), std::forward<Tail>(t)...);
}


template<typename U = R>
typename std::enable_if<!std::is_reference<U>::value, void>::type
AlwaysReturn(const R &r) {
return AlwaysDo([r](const typename fakeit::test_arg<arglist>::type...) -> R { return r; });
}

template<typename U = R>
typename std::enable_if<std::is_reference<U>::value, void>::type
AlwaysReturn(const R &r) {
return AlwaysDo([&r](const typename fakeit::test_arg<arglist>::type...) -> R { return r; });
}

template<typename U = R>
typename std::enable_if<std::is_copy_constructible<U>::value, void>::type
AlwaysReturnCopy(const R& r) {
return AlwaysDo([r](const typename fakeit::test_arg<arglist>::type...) -> R { return r; });
}

MethodStubbingProgress<R, arglist...> &
Return() {
return Do([](const typename fakeit::test_arg<arglist>::type...) -> R { return DefaultValue<R>::value(); });
Expand Down Expand Up @@ -198,10 +216,6 @@ namespace fakeit {
return Do(s, t...);
}

virtual void AlwaysDo(std::function<R(const typename fakeit::test_arg<arglist>::type...)> method) {
DoImpl(new RepeatForever<R, arglist...>(method));
}

private:
MethodStubbingProgress &operator=(const MethodStubbingProgress &other) = delete;

Expand Down
24 changes: 12 additions & 12 deletions tests/referece_types_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ struct ReferenceTypesTests: tpunit::TestFixture {
int num{ 1 };

// explicit copy here
When(Method(mock, returnStringByConstRef)).ReturnCopy(a_string);
When(Method(mock, returnIntByRef)).ReturnCopy(num);
When(Method(mock, returnStringByConstRef)).ReturnCapture(a_string);
When(Method(mock, returnIntByRef)).ReturnCapture(num);

// modify now so know whether or not is was copied
a_string = "modified";
Expand All @@ -168,10 +168,10 @@ struct ReferenceTypesTests: tpunit::TestFixture {
Mock<ReferenceInterface> mock;

{
When(Method(mock, returnStringByConstRef)).Return(std::string{ "myConstRefString" });
When(Method(mock, returnStringByRef)).Return(std::string{ "myRefString" });
When(Method(mock, returnConcreteTypeByRef)).Return(ConcreteType(20));
When(Method(mock, returnIntByRef)).Return(1);
When(Method(mock, returnStringByConstRef)).ReturnCapture(std::string{ "myConstRefString" });
When(Method(mock, returnStringByRef)).ReturnCapture(std::string{ "myRefString" });
When(Method(mock, returnConcreteTypeByRef)).ReturnCapture(ConcreteType(20));
When(Method(mock, returnIntByRef)).ReturnCapture(1);
}

ReferenceInterface& i = mock.get();
Expand All @@ -186,10 +186,10 @@ struct ReferenceTypesTests: tpunit::TestFixture {
Mock<ReferenceInterface> mock;

{
When(Method(mock, returnStringByConstRef)).AlwaysReturn(std::string{ "myConstRefString" });
When(Method(mock, returnStringByRef)).AlwaysReturn(std::string{ "myRefString" });
When(Method(mock, returnConcreteTypeByRef)).AlwaysReturn(ConcreteType(20));
When(Method(mock, returnIntByRef)).AlwaysReturn(1);
When(Method(mock, returnStringByConstRef)).AlwaysReturnCapture(std::string{ "myConstRefString" });
When(Method(mock, returnStringByRef)).AlwaysReturnCapture(std::string{ "myRefString" });
When(Method(mock, returnConcreteTypeByRef)).AlwaysReturnCapture(ConcreteType(20));
When(Method(mock, returnIntByRef)).AlwaysReturnCapture(1);
}

ReferenceInterface& i = mock.get();
Expand Down Expand Up @@ -233,8 +233,8 @@ struct ReferenceTypesTests: tpunit::TestFixture {
int num{ 1 };

// explicit copy here
When(Method(mock, returnStringByConstRef)).AlwaysReturnCopy(a_string);
When(Method(mock, returnIntByRef)).AlwaysReturnCopy(num);
When(Method(mock, returnStringByConstRef)).AlwaysReturnCapture(a_string);
When(Method(mock, returnIntByRef)).AlwaysReturnCapture(num);

// modify now so know whether or not is was copied
a_string = "modified";
Expand Down

0 comments on commit f507d1c

Please sign in to comment.