Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Unique by Key Implementation for c.parallel #3947

Draft
wants to merge 39 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
682bcd2
Move unique_by_key kernels to NVRTC compilable header
NaderAlAwar Feb 14, 2025
1d4e74f
Add dynamic dispatch for unique_by_key
NaderAlAwar Feb 14, 2025
bc363a8
Add missing host device qualifier
NaderAlAwar Feb 14, 2025
9704eb9
Fix issue where argument name was not changed
NaderAlAwar Feb 14, 2025
3ac97f6
Merge branch 'main' into unique-by-key-dynamic-dispatch
NaderAlAwar Feb 17, 2025
00da18f
Forgot to rename local variable
NaderAlAwar Feb 17, 2025
440716e
Remove old TODO comment
NaderAlAwar Feb 18, 2025
240078d
Add initial c.parallel implementation of building unique_by_key
NaderAlAwar Feb 18, 2025
2a17133
Make VSMemHelper be a template parameter of invoke(). This is needed …
NaderAlAwar Feb 19, 2025
adc4acc
Revert "Make VSMemHelper be a template parameter of invoke(). This is…
NaderAlAwar Feb 20, 2025
5b210ec
Make vsmem_helper a template on the dispatch class so that we can pas…
NaderAlAwar Feb 20, 2025
dff4361
Refactor how we instantiate vsmem_helper
NaderAlAwar Feb 20, 2025
caed37f
Pass VSMemHelper as template parameter to unique by key kernel
NaderAlAwar Feb 21, 2025
f37206b
Move the template parameters of VSMemHelper to its methods to workaro…
NaderAlAwar Feb 24, 2025
d1c885c
Make vsmem helper functions host device
NaderAlAwar Feb 24, 2025
666803b
Merge branch 'unique-by-key-dynamic-dispatch' of https://github.com/n…
NaderAlAwar Feb 24, 2025
1ef5464
Merge branch 'main' into unique-by-key-dynamic-dispatch
NaderAlAwar Feb 24, 2025
c6ef86f
Remove unused macro
NaderAlAwar Feb 24, 2025
e5bffb6
Merge branch 'unique-by-key-dynamic-dispatch' into unique-by-key-c-pa…
NaderAlAwar Feb 24, 2025
0d57ebd
Update unique_by_key c parallel implementation following recent changes
NaderAlAwar Feb 24, 2025
d7d7928
Add tuning policies to c parallel unique_by_key
NaderAlAwar Feb 25, 2025
4d0372c
Make KeyT and ValueT templates for DispatchUniqueByKey
NaderAlAwar Feb 25, 2025
b37f88a
Merge branch 'unique-by-key-dynamic-dispatch' into unique-by-key-c-pa…
NaderAlAwar Feb 25, 2025
131a8a9
Add first c parallel unique_by_key test
NaderAlAwar Feb 25, 2025
c653ce2
Make vsmem helper methods static
NaderAlAwar Feb 25, 2025
e666a57
Move scan_tile_state to separate source file to reuse it in unique_by…
NaderAlAwar Feb 25, 2025
1ed7a3d
Add more tests for unique_by_key
NaderAlAwar Feb 25, 2025
895852d
Fix SFINAE for UniqueByKeyPolicyWrapper
NaderAlAwar Feb 25, 2025
6bde2ab
Merge branch 'unique-by-key-dynamic-dispatch' into unique-by-key-c-pa…
NaderAlAwar Feb 25, 2025
18aaae2
Fix checking output in unique_by_key test
NaderAlAwar Feb 25, 2025
5456587
Add none equal and all equal tests for c.parallel unique_by_key
NaderAlAwar Feb 26, 2025
dfe5cc7
Add custom type test for c.parallel unique_by_key
NaderAlAwar Feb 26, 2025
fb8f0aa
Add iterator test for c.parallel unique_by_key
NaderAlAwar Feb 26, 2025
cb1806c
Add support for output iterators for c.parallel unique_by_key
NaderAlAwar Feb 26, 2025
fa0f687
Add comment explaining vsmem design
NaderAlAwar Feb 26, 2025
6a284d2
Merge branch 'unique-by-key-dynamic-dispatch' into unique-by-key-c-pa…
NaderAlAwar Feb 26, 2025
b9a4201
Use smaller number of items for none equal unique by key test
NaderAlAwar Feb 26, 2025
1aed2b4
Run pre-commit
NaderAlAwar Feb 26, 2025
4dc9a3c
Merge branch 'unique-by-key-c-parallel' of https://github.com/NaderAl…
NaderAlAwar Feb 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions c/parallel/include/cccl/c/unique_by_key.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
//===----------------------------------------------------------------------===//
//
// Part of CUDA Experimental in CUDA Core Compute Libraries,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#pragma once

#ifndef CCCL_C_EXPERIMENTAL
# error "C exposure is experimental and subject to change. Define CCCL_C_EXPERIMENTAL to acknowledge this notice."
#endif // !CCCL_C_EXPERIMENTAL

#include <cuda.h>

#include <cccl/c/extern_c.h>
#include <cccl/c/types.h>

CCCL_C_EXTERN_C_BEGIN

// Build state for the unique-by-key algorithm. In this header it is taken by pointer by
// cccl_device_unique_by_key_build, by value by cccl_device_unique_by_key, and by pointer by
// cccl_device_unique_by_key_cleanup.
typedef struct cccl_device_unique_by_key_build_result_t
{
// NOTE(review): presumably the compute capability the kernels were built for — confirm
// the encoding (e.g. major * 10 + minor) against the build implementation.
int cc;
// Pointer to the compiled binary image and its size in bytes.
// NOTE(review): ownership is not visible here; presumably released by the cleanup call — confirm.
void* cubin;
size_t cubin_size;
// CUDA driver library handle; presumably loaded from `cubin` — confirm.
CUlibrary library;
// Kernel handles for the two launches of the algorithm (an init kernel and a sweep kernel).
// NOTE(review): presumably looked up in `library` — confirm.
CUkernel compact_init_kernel;
CUkernel sweep_kernel;
// Per-tile byte counts for the scan tile state.
// NOTE(review): presumably used to size and partition temporary storage — confirm.
size_t description_bytes_per_tile;
size_t payload_bytes_per_tile;
} cccl_device_unique_by_key_build_result_t;

// Prepares the unique-by-key algorithm for the given iterator/operator descriptions and
// stores the outcome in `*build`.
//
// build               - output; receives the build result.
// d_keys_in/d_values_in   - descriptions of the input key and value sequences.
// d_keys_out/d_values_out - descriptions of the output key and value sequences.
// d_num_selected_out  - description of where the number of selected items is written.
// op                  - NOTE(review): presumably the key equality operation — confirm.
// cc_major, cc_minor  - target compute capability.
// cub_path, thrust_path, libcudacxx_path, ctk_path
//                     - NOTE(review): presumably include-path options forwarded to the runtime
//                       compiler — confirm the expected format (e.g. "-I<dir>").
//
// Returns a CUDA driver API status code.
// NOTE(review): presumably CUDA_SUCCESS on success; the error mapping is not visible here.
CCCL_C_API CUresult cccl_device_unique_by_key_build(
cccl_device_unique_by_key_build_result_t* build,
cccl_iterator_t d_keys_in,
cccl_iterator_t d_values_in,
cccl_iterator_t d_keys_out,
cccl_iterator_t d_values_out,
cccl_iterator_t d_num_selected_out,
cccl_op_t op,
int cc_major,
int cc_minor,
const char* cub_path,
const char* thrust_path,
const char* libcudacxx_path,
const char* ctk_path) noexcept;

// Runs unique-by-key on `num_items` key/value pairs using a previously produced build result.
//
// build              - build result (passed by value).
// d_temp_storage     - temporary device storage.
// temp_storage_bytes - in/out size of the temporary storage.
//                      NOTE(review): presumably follows the CUB two-phase convention (null
//                      d_temp_storage => only the required size is written) — confirm.
// d_keys_in/d_values_in   - input key and value sequences.
// d_keys_out/d_values_out - output key and value sequences.
// d_num_selected_out - receives the number of items written to the outputs.
// op                 - NOTE(review): presumably the key equality operation — confirm.
// num_items          - number of input items.
// stream             - CUDA stream.
//                      NOTE(review): presumably the stream the work is launched on — confirm.
//
// Returns a CUDA driver API status code.
CCCL_C_API CUresult cccl_device_unique_by_key(
cccl_device_unique_by_key_build_result_t build,
void* d_temp_storage,
size_t* temp_storage_bytes,
cccl_iterator_t d_keys_in,
cccl_iterator_t d_values_in,
cccl_iterator_t d_keys_out,
cccl_iterator_t d_values_out,
cccl_iterator_t d_num_selected_out,
cccl_op_t op,
unsigned long long num_items,
CUstream stream) noexcept;

// Cleans up a build result.
// NOTE(review): presumably releases the resources recorded in `*bld_ptr` (binary image, loaded
// library) — confirm, including whether a null pointer is accepted.
// Returns a CUDA driver API status code.
CCCL_C_API CUresult cccl_device_unique_by_key_cleanup(cccl_device_unique_by_key_build_result_t* bld_ptr) noexcept;

CCCL_C_EXTERN_C_END
18 changes: 9 additions & 9 deletions c/parallel/src/kernels/iterators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,28 +97,28 @@ std::string make_kernel_output_iterator(
const std::string iter_def = std::format(R"XXX(
extern "C" __device__ void DEREF(const void *self_ptr, VALUE_T x);
extern "C" __device__ void ADVANCE(void *self_ptr, DIFF_T offset);
struct __align__(OP_ALIGNMENT) output_iterator_state_t {{
struct __align__(OP_ALIGNMENT) {0}_state_t {{
char data[OP_SIZE];
}};
struct output_iterator_proxy_t {{
__device__ output_iterator_proxy_t operator=(VALUE_T x) {{
struct {0}_proxy_t {{
__device__ {0}_proxy_t operator=(VALUE_T x) {{
DEREF(&state, x);
return *this;
}}
output_iterator_state_t state;
{0}_state_t state;
}};
struct {0} {{
using iterator_category = cuda::std::random_access_iterator_tag;
using difference_type = DIFF_T;
using value_type = void;
using pointer = output_iterator_proxy_t*;
using reference = output_iterator_proxy_t;
__device__ output_iterator_proxy_t operator*() const {{ return {{state}}; }}
using pointer = {0}_proxy_t*;
using reference = {0}_proxy_t;
__device__ {0}_proxy_t operator*() const {{ return {{state}}; }}
__device__ {0}& operator+=(difference_type diff) {{
ADVANCE(&state, diff);
return *this;
}}
__device__ output_iterator_proxy_t operator[](difference_type diff) const {{
__device__ {0}_proxy_t operator[](difference_type diff) const {{
{0} result = *this;
result += diff;
return {{ result.state }};
Expand All @@ -128,7 +128,7 @@ struct {0} {{
result += diff;
return result;
}}
output_iterator_state_t state;
{0}_state_t state;
}};
)XXX",
iterator_name);
Expand Down
110 changes: 3 additions & 107 deletions c/parallel/src/scan.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
//
//===----------------------------------------------------------------------===//

#include <cub/agent/single_pass_scan_operators.cuh>
#include <cub/detail/choose_offset.cuh>
#include <cub/detail/launcher/cuda_driver.cuh>
#include <cub/device/dispatch/dispatch_scan.cuh>
Expand All @@ -20,7 +19,6 @@
#include <format>
#include <iostream>
#include <optional>
#include <regex>
#include <string>
#include <type_traits>

Expand All @@ -30,6 +28,7 @@
#include "util/context.h"
#include "util/errors.h"
#include "util/indirect_arg.h"
#include "util/scan_tile_state.h"
#include "util/types.h"
#include <cccl/c/scan.h>
#include <nvrtc.h>
Expand Down Expand Up @@ -172,74 +171,6 @@ std::string get_scan_kernel_name(cccl_iterator_t input_it, cccl_iterator_t outpu
init_t); // 9
}

// TODO: NVRTC doesn't currently support extracting basic type
// information (e.g., type sizes and alignments) from compiled
// LTO-IR. So we separately compile a small PTX file that defines the
// necessary types and constants and grep it for the required
// information. If/when NVRTC adds these features, we can remove this
// extra compilation step and get the information directly from the
// LTO-IR.
//
// Pattern template for a `.visible .global .align N .u64 <name> = <digits>;` PTX declaration.
// The `{}` placeholder is replaced with the symbol name via std::format (see find_size_t);
// capture group 1 holds the decimal initializer.
static constexpr auto ptx_u64_assignment_regex = R"(\.visible\s+\.global\s+\.align\s+\d+\s+\.u64\s+{}\s*=\s*(\d+);)";

// Looks up the initializer of a `.u64` global named `name` in PTX text.
//
// ptx  - NUL-terminated PTX source to search; a null pointer yields std::nullopt.
// name - symbol name substituted into ptx_u64_assignment_regex.
//
// Returns the parsed value of the first matching declaration, or std::nullopt when no
// declaration matches. Throws std::out_of_range only if the literal does not fit in
// unsigned long long.
std::optional<size_t> find_size_t(char* ptx, std::string_view name)
{
  if (ptx == nullptr)
  {
    return std::nullopt;
  }

  const std::regex regex(std::format(ptx_u64_assignment_regex, name));
  std::cmatch match;
  if (!std::regex_search(ptx, match, regex))
  {
    return std::nullopt;
  }

  // The literal initializes a 64-bit unsigned global, so parse it at full width: std::stoi
  // would throw for anything above INT_MAX and narrow the value through `int`.
  return static_cast<size_t>(std::stoull(match[1].str()));
}

struct scan_tile_state
{
  // Host-side stand-in for cub::ScanTileStateT with the same interface, except that the
  // accumulator type is described at runtime (through the two per-tile byte counts) instead
  // of being a template parameter.
  //
  // It covers both ScanTileStateT<T, true> and ScanTileStateT<T, false> - the bool says
  // whether `T` is primitive - with one non-templated type.

  void* d_tile_status    = nullptr; // d_tile_descriptors
  void* d_tile_partial   = nullptr;
  void* d_tile_inclusive = nullptr;

  size_t description_bytes_per_tile;
  size_t payload_bytes_per_tile;

  scan_tile_state(size_t description_bytes_per_tile, size_t payload_bytes_per_tile)
      : description_bytes_per_tile(description_bytes_per_tile)
      , payload_bytes_per_tile(payload_bytes_per_tile)
  {}

  // Splits `d_temp_storage` into the three tile-state regions for `num_tiles` tiles.
  // The pointers are only updated when the underlying call succeeds; its status is returned.
  cudaError_t Init(int num_tiles, void* d_temp_storage, size_t temp_storage_bytes)
  {
    void* regions[3] = {};
    const auto error = cub::detail::tile_state_init(
      description_bytes_per_tile, payload_bytes_per_tile, num_tiles, d_temp_storage, temp_storage_bytes, regions);
    if (error == cudaSuccess)
    {
      d_tile_status    = regions[0];
      d_tile_partial   = regions[1];
      d_tile_inclusive = regions[2];
    }
    return error;
  }

  // Writes the number of temporary-storage bytes needed for `num_tiles` tiles into
  // `temp_storage_bytes`. Always reports cudaSuccess.
  cudaError_t AllocationSize(int num_tiles, size_t& temp_storage_bytes) const
  {
    const auto required_bytes =
      cub::detail::tile_state_allocation_size(description_bytes_per_tile, payload_bytes_per_tile, num_tiles);
    temp_storage_bytes = required_bytes;
    return cudaSuccess;
  }
};

template <auto* GetPolicy>
struct dynamic_scan_policy_t
{
Expand Down Expand Up @@ -392,43 +323,8 @@ struct device_scan_policy {{
check(cuLibraryGetKernel(&build_ptr->init_kernel, build_ptr->library, init_kernel_lowered_name.c_str()));
check(cuLibraryGetKernel(&build_ptr->scan_kernel, build_ptr->library, scan_kernel_lowered_name.c_str()));

constexpr size_t num_ptx_args = 7;
const char* ptx_args[num_ptx_args] = {
arch.c_str(), cub_path, thrust_path, libcudacxx_path, ctk_path, "-rdc=true", "-dlto"};
constexpr size_t num_ptx_lto_args = 3;
const char* ptx_lopts[num_ptx_lto_args] = {"-lto", arch.c_str(), "-ptx"};

constexpr std::string_view ptx_src_template = R"XXX(
#include <cub/agent/single_pass_scan_operators.cuh>
#include <cub/util_type.cuh>
struct __align__({1}) storage_t {{
char data[{0}];
}};
__device__ size_t description_bytes_per_tile = cub::ScanTileState<{2}>::description_bytes_per_tile;
__device__ size_t payload_bytes_per_tile = cub::ScanTileState<{2}>::payload_bytes_per_tile;
)XXX";

const std::string ptx_src = std::format(ptx_src_template, accum_t.size, accum_t.alignment, accum_cpp);
auto compile_result =
make_nvrtc_command_list()
.add_program(nvrtc_translation_unit{ptx_src.c_str(), "tile_state_info"})
.compile_program({ptx_args, num_ptx_args})
.cleanup_program()
.finalize_program(num_ptx_lto_args, ptx_lopts);
auto ptx_code = compile_result.data.get();

size_t description_bytes_per_tile;
size_t payload_bytes_per_tile;
auto maybe_description_bytes_per_tile = scan::find_size_t(ptx_code, "description_bytes_per_tile");
if (maybe_description_bytes_per_tile)
{
description_bytes_per_tile = maybe_description_bytes_per_tile.value();
}
else
{
throw std::runtime_error("Failed to find description_bytes_per_tile in PTX");
}
payload_bytes_per_tile = scan::find_size_t(ptx_code, "payload_bytes_per_tile").value_or(0);
auto [description_bytes_per_tile,
payload_bytes_per_tile] = get_tile_state_bytes_per_tile(accum_t, accum_cpp, args, num_args, arch);

build_ptr->cc = cc;
build_ptr->cubin = (void*) result.data.release();
Expand Down
Loading