Skip to content

Commit

Permalink
adds tests for vsmem
Browse files Browse the repository at this point in the history
  • Loading branch information
elstehle committed Feb 20, 2024
1 parent c7ec7ff commit fdf565e
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 1 deletion.
1 change: 0 additions & 1 deletion cub/cub/device/dispatch/dispatch_select_if.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,6 @@ __launch_bounds__(int(
// Static shared memory allocation
__shared__ typename VsmemHelperT::static_temp_storage_t static_temp_storage;

// Shared memory for AgentSelectIf
// Get temporary storage
typename AgentSelectIfT::TempStorage& temp_storage = VsmemHelperT::get_temp_storage(static_temp_storage, vsmem);

Expand Down
90 changes: 90 additions & 0 deletions cub/test/catch2_test_device_select_if_vsmem.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/******************************************************************************
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
* * Neither the name of the NVIDIA CORPORATION nor the
* names of its contributors may be used to endorse or promote products
* derived from this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/

#include <cub/device/device_select.cuh>

#include <thrust/device_vector.h>
#include <thrust/host_vector.h>

#include <algorithm>

#include "catch2_test_helper.h"
#include "catch2_test_launch_helper.h"

// %PARAM% TEST_LAUNCH lid 0:1:2

DECLARE_LAUNCH_WRAPPER(cub::DeviceSelect::If, select_if);

using types = c2h::type_list<
// Type large enough to dispatch to the fallback policy
c2h::custom_type_t<c2h::equal_comparable_t, c2h::less_comparable_t, c2h::huge_data<256>::type>,
// Type large enough to require virtual shared memory
c2h::custom_type_t<c2h::equal_comparable_t, c2h::less_comparable_t, c2h::huge_data<512>::type>>;

template <typename T>
struct less_than_t
{
T compare;

explicit __host__ less_than_t(T compare)
: compare(compare)
{}

__host__ __device__ bool operator()(const T& a) const
{
return a < compare;
}
};

CUB_TEST("DeviceSelect::If works for large types", "[select_if][vsmem][device]", types)
{
using type = typename c2h::get<0, TestType>;

const int num_items = GENERATE_COPY(take(2, random(1, 10000)));
thrust::device_vector<type> in(num_items);
thrust::device_vector<type> out(num_items);
c2h::gen(CUB_SEED(2), in);

// just pick one of the input elements as boundary
less_than_t<type> le{in[num_items / 2]};

// Needs to be device accessible
thrust::device_vector<int> num_selected_out(1, 0);
int* d_first_num_selected_out = thrust::raw_pointer_cast(num_selected_out.data());

select_if(in.begin(), out.begin(), num_selected_out.begin(), num_items, le);

std::cout << "Selected: " << num_selected_out[0] << "/" << num_items << "\n";

// Ensure that we create the same output as std
thrust::host_vector<type> reference = in;
std::stable_partition(reference.begin(), reference.end(), le);

out.resize(num_selected_out[0]);
reference.resize(num_selected_out[0]);
REQUIRE(reference == out);
}

0 comments on commit fdf565e

Please sign in to comment.