Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 24 additions & 66 deletions stl/inc/mdspan
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ public:
}

template <class _ExtentsT>
friend constexpr pair<size_t, size_t> _Count_dynamic_extents_equal_to_zero_or_one(
friend constexpr size_t _Count_dynamic_extents_equal_to_one(
const _ExtentsT&) noexcept; // NB: used by 'layout_stride::mapping<E>::is_exhaustive'
};

Expand Down Expand Up @@ -797,26 +797,10 @@ concept _Layout_mapping_alike = requires {
bool_constant<_Mp::is_always_unique()>::value;
};

template <class _Extents, size_t _Val>
constexpr size_t _Count_static_extents_equal_to = 0;

template <class _IndexType, size_t... _Extents, size_t _Val>
constexpr size_t _Count_static_extents_equal_to<extents<_IndexType, _Extents...>, _Val> =
(static_cast<size_t>(_Extents == _Val) + ... + 0);

template <class _Extents>
_NODISCARD constexpr pair<size_t, size_t> _Count_dynamic_extents_equal_to_zero_or_one(const _Extents& _Exts) noexcept {
_NODISCARD constexpr size_t _Count_dynamic_extents_equal_to_one(const _Extents& _Exts) noexcept {
_STL_INTERNAL_STATIC_ASSERT(_Is_extents<_Extents> && _Extents::rank_dynamic() != 0);
size_t _Zero_extents = 0;
size_t _One_extents = 0;
for (const auto& _Ext : _Exts._Array) {
if (_Ext == 0) {
++_Zero_extents;
} else if (_Ext == 1) {
++_One_extents;
}
}
return {_Zero_extents, _One_extents};
return static_cast<size_t>(_RANGES count(_Exts._Array, static_cast<_Extents::index_type>(1)));
}

template <class _Extents>
Expand All @@ -833,7 +817,7 @@ public:
private:
using _Extents_base = _Maybe_fully_static_extents<extents_type>;
using _Strides_base = _Maybe_empty_array<index_type, _Extents::rank()>;
using _Stride_extent_pair = pair<rank_type, rank_type>;
using _Stride_extent_pair = pair<index_type, index_type>;

static_assert(_Is_extents<extents_type>,
"Extents must be a specialization of std::extents (N4950 [mdspan.layout.stride.overview]/2).");
Expand Down Expand Up @@ -869,7 +853,7 @@ private:
|| _Add_overflow(_Req_span_size, _Prod, _Req_span_size);
}

_Pairs[_Idx] = {static_cast<rank_type>(_Stride), static_cast<rank_type>(_Ext)};
_Pairs[_Idx] = {_Stride, _Ext};
}
_STL_VERIFY(_Found_zero || !_Overflow, "REQUIRED-SPAN-SIZE(e, s) must be representable as a value of type "
"index_type (N4950 [mdspan.layout.stride.cons]/4.2).");
Expand Down Expand Up @@ -998,7 +982,7 @@ public:
}

_NODISCARD static constexpr bool is_always_exhaustive() noexcept {
return false;
return extents_type::_Rank == 0 || extents_type::_Multidim_index_space_size_is_always_zero;
}

_NODISCARD static constexpr bool is_always_strided() noexcept {
Expand All @@ -1010,53 +994,28 @@ public:
}

_NODISCARD constexpr bool is_exhaustive() const noexcept {
constexpr size_t _Static_zero_extents = _Count_static_extents_equal_to<extents_type, 0>;
if constexpr (extents_type::rank() == 0) {
if constexpr (is_always_exhaustive()) {
return true;
} else if constexpr (extents_type::rank() == 1) {
return this->_Array[0] == 1;
} else if constexpr (_Static_zero_extents >= 2) {
// Per N5008 [mdspan.layout.stride.obs]/5.2, we are looking for a permutation P of integers in the range
// '[0, rank)' such that 'stride(p[i]) == stride(p[i-1])*extent(p[i-1])' is true for 'i' in the range
// '[1, rank)'. Knowing that at least two extents are equal to zero, we can deduce that such a permutation
// does not exist:
// - Some 'stride(p[j])' would have to be equal to 'stride(p[j-1])*extents(p[j-1]) = stride(p[j-1])*0 = 0'
// which is not possible.
// - Only 'extent(p[rank-1])' can be equal to 0, because it's not required to satisfy the condition above.
// Since we have two or more extents equal to 0 this is not possible either.
return false;
return this->_Array[0] == 1 || this->_Exts.extent(0) == 0;
} else if constexpr (extents_type::rank() == 2) {
return (this->_Array[0] == 1 && this->_Array[1] == this->_Exts.extent(0))
|| (this->_Array[1] == 1 && this->_Array[0] == this->_Exts.extent(1));
} else {
// NB: Extents equal to 1 are problematic too - sometimes in such cases even when the mapping is exhaustive
|| (this->_Array[1] == 1 && this->_Array[0] == this->_Exts.extent(1)) || this->_Exts.extent(0) == 0
|| this->_Exts.extent(1) == 0;
} else if constexpr (_RANGES count(extents_type::_Static_extents, static_cast<rank_type>(1)) != 0) {
// NB: Extents equal to 1 are problematic - sometimes in such cases even when the mapping is exhaustive
// this function should return false.
// For example, when the extents are [2, 1, 2] and the strides are [1, 5, 2], the mapping is exhaustive
// per N5008 [mdspan.layout.reqmts]/16 but not per N5008 [mdspan.layout.stride.obs]/5.2.
constexpr size_t _Static_zero_or_one_extents =
_Static_zero_extents + _Count_static_extents_equal_to<extents_type, 1>;

if constexpr (extents_type::rank_dynamic() != 0) {
const auto [_Dynamic_zero_extents, _Dynamic_one_extents] =
_Count_dynamic_extents_equal_to_zero_or_one(this->_Exts);

const size_t _All_zero_extents = _Static_zero_extents + _Dynamic_zero_extents;
if (_All_zero_extents >= 2) {
return false;
}

const size_t _All_zero_or_one_extents =
_Static_zero_or_one_extents + _Dynamic_zero_extents + _Dynamic_one_extents;
if (_All_zero_or_one_extents == 0) {
return _Is_exhaustive_common_case();
}

return _Is_exhaustive_special_case();
} else if constexpr (_Static_zero_or_one_extents == 0) {
return _Is_exhaustive_common_case();
} else {
return _Is_exhaustive_special_case();
} else if constexpr (extents_type::rank_dynamic() == 0) {
return _Is_exhaustive_common_case();
} else {
if (_Count_dynamic_extents_equal_to_one(this->_Exts) != 0) {
return _Is_exhaustive_special_case();
}

return _Is_exhaustive_common_case();
}
}

Expand Down Expand Up @@ -1127,20 +1086,19 @@ private:
}

_NODISCARD constexpr bool _Is_exhaustive_common_case() const noexcept {
return required_span_size() == _Fwd_prod_of_extents<extents_type>::_Calculate(this->_Exts, extents_type::_Rank);
const index_type _Prod = _Fwd_prod_of_extents<extents_type>::_Calculate(this->_Exts, extents_type::_Rank);
return _Prod == required_span_size() || _Prod == 0;
}

_NODISCARD constexpr bool _Is_exhaustive_special_case() const noexcept {
array<_Stride_extent_pair, extents_type::rank()> _Pairs;
for (rank_type _Idx = 0; _Idx < extents_type::_Rank; ++_Idx) {
rank_type _Ext = static_cast<rank_type>(this->_Exts.extent(_Idx));
const index_type _Ext = this->_Exts.extent(_Idx);
if (_Ext == 0) {
// NB: _Ext equal to zero is special - we want it to end up as close to the end of the sorted range as
// possible, so we assign max value of rank_type to it.
_Ext = static_cast<rank_type>(-1);
return true;
}

_Pairs[_Idx] = {static_cast<rank_type>(this->_Array[_Idx]), _Ext};
_Pairs[_Idx] = {this->_Array[_Idx], _Ext};
}

_RANGES sort(_Pairs);
Expand Down
7 changes: 4 additions & 3 deletions tests/libcxx/expected_results.txt
Original file line number Diff line number Diff line change
Expand Up @@ -184,12 +184,13 @@ std/ranges/range.adaptors/range.lazy.split/range.lazy.split.outer.value/ctor.ite
# libc++ doesn't implement LWG-4112
std/ranges/range.adaptors/range.join/range.join.iterator/arrow.pass.cpp FAIL

# libc++ doesn't implement LWG-4266 "`layout_stride::mapping` should treat empty mappings as exhaustive"
std/containers/views/mdspan/layout_stride/is_exhaustive_corner_case.pass.cpp FAIL
std/containers/views/mdspan/layout_stride/properties.pass.cpp FAIL

# If any feature-test macro test is failing, this consolidated test will also fail.
std/language.support/support.limits/support.limits.general/version.version.compile.pass.cpp FAIL

# libc++ incorrectly implements `layout_stride::mapping<E>::is_exhaustive()`
std/containers/views/mdspan/layout_stride/is_exhaustive_corner_case.pass.cpp FAIL


# *** INTERACTIONS WITH MSVC THAT UPSTREAM LIKELY WON'T FIX ***
# These tests set an allocator with a max_size() too small to default construct an unordered container
Expand Down
38 changes: 19 additions & 19 deletions tests/std/tests/P0009R18_mdspan_layout_stride/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ constexpr void do_check_members(const extents<IndexType, Extents...>& ext,

{ // Check 'is_always_[unique/exhaustive/strided]' functions
static_assert(Mapping::is_always_unique());
static_assert(!Mapping::is_always_exhaustive());
static_assert(Mapping::is_always_exhaustive() == (Ext::rank() == 0 || ((Extents == 0) || ...)));
static_assert(Mapping::is_always_strided());
}

Expand Down Expand Up @@ -410,7 +410,7 @@ constexpr void check_is_exhaustive() {

// rank() is equal to 1
check(extents<int, 0>{}, array{1}, true);
check(dextents<int, 1>{0}, array{2}, false);
check(dextents<int, 1>{0}, array{2}, true);
check(extents<int, 1>{}, array{3}, false);
check(dextents<int, 1>{2}, array{2}, false);
check(extents<int, 3>{}, array{1}, true);
Expand All @@ -426,14 +426,14 @@ constexpr void check_is_exhaustive() {
check(extents<int, dynamic_extent, 7>{5}, array{1, 8}, false);
check(dextents<int, 2>{6, 5}, array{1, 10}, false);
check(extents<int, 0, 3>{}, array{3, 1}, true);
check(extents<int, 0, 3>{}, array{6, 2}, false);
check(extents<int, dynamic_extent, 3>{0}, array{6, 1}, false);
check(extents<int, 0, dynamic_extent>{3}, array{6, 2}, false);
check(dextents<int, 2>{0, 3}, array{7, 2}, false);
check(extents<int, 0, 0>{}, array{1, 1}, false);
check(extents<int, 0, dynamic_extent>{0, 0}, array{1, 1}, false);
check(dextents<int, 2>{0, 0}, array{1, 2}, false);
check(extents<int, 1, dynamic_extent>{0}, array{1, 2}, false);
check(extents<int, 0, 3>{}, array{6, 2}, true);
check(extents<int, dynamic_extent, 3>{0}, array{6, 1}, true);
check(extents<int, 0, dynamic_extent>{3}, array{6, 2}, true);
check(dextents<int, 2>{0, 3}, array{7, 2}, true);
check(extents<int, 0, 0>{}, array{1, 1}, true);
check(extents<int, 0, dynamic_extent>{0, 0}, array{1, 1}, true);
check(dextents<int, 2>{0, 0}, array{1, 2}, true);
check(extents<int, 1, dynamic_extent>{0}, array{1, 2}, true);

// rank() is greater than 2
check(extents<int, 2, 3, 5>{}, array{1, 2, 6}, true);
Expand All @@ -449,20 +449,20 @@ constexpr void check_is_exhaustive() {
// rank() is greater than 2 and some extents are equal to 0
check(extents<int, 2, 0, 7>{}, array{7, 14, 1}, true);
check(extents<int, dynamic_extent, 0, 7>{2}, array{1, 14, 2}, true);
check(extents<int, 2, dynamic_extent, 7>{0}, array{14, 28, 1}, false);
check(extents<int, 2, dynamic_extent, dynamic_extent>{0, 7}, array{1, 2, 2}, false);
check(dextents<int, 3>{2, 0, 7}, array{2, 28, 4}, false);
check(extents<int, 5, 0, 0>{}, array{3, 1, 1}, false);
check(extents<int, 5, dynamic_extent, 0>{0}, array{1, 5, 1}, false);
check(dextents<int, 3>{5, 0, 0}, array{2, 1, 10}, false);
check(extents<int, 0, 0, 0>{}, array{1, 1, 1}, false);
check(extents<int, 2, dynamic_extent, 7>{0}, array{14, 28, 1}, true);
check(extents<int, 2, dynamic_extent, dynamic_extent>{0, 7}, array{1, 2, 2}, true);
check(dextents<int, 3>{2, 0, 7}, array{2, 28, 4}, true);
check(extents<int, 5, 0, 0>{}, array{3, 1, 1}, true);
check(extents<int, 5, dynamic_extent, 0>{0}, array{1, 5, 1}, true);
check(dextents<int, 3>{5, 0, 0}, array{2, 1, 10}, true);
check(extents<int, 0, 0, 0>{}, array{1, 1, 1}, true);
check(extents<int, 0, 1, 1>{}, array{1, 1, 1}, true);

// rank() is greater than 2 - one extent is equal to 0 while others are equal to each other
check(extents<int, 3, 0, 3>{}, array{1, 9, 3}, true);
check(extents<int, dynamic_extent, 0, 3>{3}, array{3, 9, 1}, true);
check(extents<int, 3, dynamic_extent, dynamic_extent>{0, 3}, array{1, 3, 3}, false);
check(dextents<int, 3>{3, 0, 3}, array{1, 4, 8}, false);
check(extents<int, 3, dynamic_extent, dynamic_extent>{0, 3}, array{1, 3, 3}, true);
check(dextents<int, 3>{3, 0, 3}, array{1, 4, 8}, true);
check(dextents<int, 3>{0, 1, 1}, array{1, 1, 1}, true);

// required_span_size() is equal to 1
Expand Down