Performance optimization of JSON validation #16996
Conversation
This kind of change usually improves compile time as well.
The warp-parallel validation algorithm is slower than this. The transform takes slightly less time, but the subsequent copy_if and warp-parallel validation take more than that reduction in transform time. Benchmark used for profiling: https://github.com/karthikeyann/cudf/blob/enh-profile_memusage_json/wm_benchmark.py

Code for warp-parallel validation:

constexpr auto SINGLE_THREAD_THRESHOLD = 128;
// Unicode code point escape sequence
static constexpr char UNICODE_SEQ = 0x7F;
// Invalid escape sequence
static constexpr char NON_ESCAPE_CHAR = 0x7E;
// Unicode code point escape sequence comprises four hex characters
static constexpr size_type UNICODE_HEX_DIGIT_COUNT = 4;
/**
* @brief Returns the character to output for a given escaped character that's following a
* backslash.
*
* @param escaped_char The character following the backslash.
* @return The character to output for a given character that's following a backslash
*/
__device__ __forceinline__ char get_escape_char(char escaped_char)
{
switch (escaped_char) {
case '"': return '"';
case '\\': return '\\';
case '/': return '/';
case 'b': return '\b';
case 'f': return '\f';
case 'n': return '\n';
case 'r': return '\r';
case 't': return '\t';
case 'u': return UNICODE_SEQ;
default: return NON_ESCAPE_CHAR;
}
}
__device__ __forceinline__ bool is_hex(char ch) {
return (ch >= '0' && ch <= '9') || (ch >= 'A' && ch <= 'F') || (ch >= 'a' && ch <= 'f');
}
// Add extra condition in predicate before !validate_strings
// if ((token_indices[i] - token_indices[i - 1]) > SINGLE_THREAD_THRESHOLD) return false; // large strings are validated separately
{
nvtx3::scoped_range_in<libcudf_domain> ctx{"longstr_transform"};
// Algorithm:
// for(each warp: 32 chars)
// if !allow_unquoted_control_chars and char < 32, then error
// check if escaping_backslash
// if prev is escaping_backslash and curr is not a valid escape char, then error
// if prev is escaping_backslash and curr is 'u' and end-curr < 4, then error
// if curr is backslash and end, then error
auto validate_warp_strings = cuda::proclaim_return_type<bool>(
[data = d_input.data(),
allow_unquoted_control_chars =
options.is_allowed_unquoted_control_chars()] __device__(SymbolOffsetT start,
SymbolOffsetT end) -> bool {
auto in_begin = data + start;
auto in_end = data + end;
// auto warp_id = threadIdx.x / cudf::detail::warp_size;
bool init_state;
size_type lane = threadIdx.x % cudf::detail::warp_size;
init_state = false;
// using bitfield = bitfield_warp<num_warps>;
// bitfield is_slash;
// is_slash.reset(warp_id);
// __sync_warp();
for (thread_index_type char_index = lane;
char_index < cudf::util::round_up_safe(end - start, static_cast<unsigned>(cudf::detail::warp_size));
char_index += cudf::detail::warp_size) {
namespace cg = cooperative_groups;
auto thread_block = cg::this_thread_block();
auto tile = cg::tiled_partition<cudf::detail::warp_size>(thread_block);
bool const is_within_bounds = char_index < (in_end - in_begin);
auto const c = is_within_bounds ? in_begin[char_index] : '\0';
if (!allow_unquoted_control_chars) {
if (tile.any(is_within_bounds && static_cast<int>(c) >= 0 && static_cast<int>(c) < 32)) {
return false;
}
}
bool is_escaping_backslash{false};
struct state_table {
// using bit fields instead of state[2] TODO: try if it is faster
bool state0 : 1;
bool state1 : 1;
[[nodiscard]] bool inline __device__ get(bool init_state) const { return init_state ? state1 : state0; }
};
state_table curr{is_within_bounds && c == '\\', false}; // state transition vector.
auto composite_op = [](state_table op1, state_table op2) {
// equivalent of state_table{op2.state[op1.state[0]], op2.state[op1.state[1]]};
return state_table{op1.state0 ? op2.state1 : op2.state0,
op1.state1 ? op2.state1 : op2.state0};
};
state_table scanned = cg::inclusive_scan(tile, curr, composite_op);
// SlashScan(temp_slash[warp_id]).InclusiveScan(curr, scanned, composite_op);
is_escaping_backslash = scanned.get(init_state);
// init_state = __shfl_sync(~0u, is_escaping_backslash, BLOCK_SIZE - 1);
init_state = tile.shfl(is_escaping_backslash, cudf::detail::warp_size - 1);
// __syncwarp();
// is_slash.shift(warp_id);
// is_slash.set_bits(warp_id, is_escaping_backslash);
auto const next_escaped_char = [&]() {
bool const is_within_bounds = char_index + 1 < (end - start);
auto const next_c = is_within_bounds ? in_begin[char_index + 1] : '\0';
return get_escape_char(next_c);
}();
// Strings with parsing errors are made null
bool error = false;
if (is_within_bounds) {
// curr=='\' and end, or prev=='\' and curr=='u' and end-curr < UNICODE_HEX_DIGIT_COUNT
// or prev=='\' and curr=='u' and end-curr >= UNICODE_HEX_DIGIT_COUNT and any non-hex
error |= (is_escaping_backslash /*c == '\\'*/ && char_index == (in_end - in_begin) - 1);
error |= (is_escaping_backslash && next_escaped_char == NON_ESCAPE_CHAR);
// use short-circuiting || so the hex digits are never read past in_end
error |= (is_escaping_backslash && next_escaped_char == UNICODE_SEQ &&
((in_begin + char_index + 1 + UNICODE_HEX_DIGIT_COUNT >= in_end) ||
!is_hex(in_begin[char_index + 2]) || !is_hex(in_begin[char_index + 3]) ||
!is_hex(in_begin[char_index + 4]) || !is_hex(in_begin[char_index + 5])));
}
// error = __any_sync(~0u, error);
error = tile.any(error);
if (error) { return false; }
}
return true;
});
// reduce > SINGLE_THREAD_THRESHOLD size to get count, allocate, and validate using a warp
auto is_large_string = cuda::proclaim_return_type<bool>(
[tokens = tokens.begin(),
token_indices = token_indices.begin()] __device__(auto i) -> bool {
if (tokens[i] == token_t::FieldNameEnd || tokens[i] == token_t::StringEnd)
return (token_indices[i] - token_indices[i - 1]) > SINGLE_THREAD_THRESHOLD;
return false;
});
auto is_large_string2 = cuda::proclaim_return_type<size_type>(
[is_large_string, predicate, d_invalid = d_invalid.begin(),
token_indices = token_indices.begin()] __device__(auto i) -> size_type {
if (predicate(i)) d_invalid[i] = true;
return is_large_string(i);
});
auto num_large_strings =
[&](){
auto it = thrust::make_transform_iterator(count_it, is_large_string2);
nvtx3::scoped_range_in<libcudf_domain> ctx{"count_if"};
return thrust::reduce(
rmm::exec_policy(stream),
it,
it + num_tokens);
}();
rmm::device_uvector<size_type> large_string_indices(num_large_strings, stream);
{
nvtx3::scoped_range_in<libcudf_domain> ctx{"copy_if"};
thrust::copy_if(rmm::exec_policy(stream),
count_it,
count_it + num_tokens,
large_string_indices.begin(),
is_large_string);
}
// validate large strings using warp level parallelism
// can you write using for_each? only if shared memory is not used in scan.
auto count_it2 = thrust::make_counting_iterator<size_t>(0);
{
nvtx3::scoped_range_in<libcudf_domain> ctx{"for_each"};
thrust::for_each(rmm::exec_policy(stream),
count_it2,
count_it2 + num_large_strings * cudf::detail::warp_size,
[tokens = tokens.begin(),
token_indices = token_indices.begin(),
validate_warp_strings,
d_invalid = d_invalid.begin(),
large_string_indices = large_string_indices.begin()] __device__(auto i) {
// TODO warp scan code to validate string.
auto idx = i / cudf::detail::warp_size;
auto lane_id = i % cudf::detail::warp_size;
auto token_idx = large_string_indices[idx];
auto is_valid = validate_warp_strings(token_indices[token_idx - 1], token_indices[token_idx]);
if(is_valid == false and lane_id == 0) {
d_invalid[token_idx] = true;
}
});
}
}
Amazing optimization! Got a few small questions and suggestions.
Tests pass and this is a big performance boost. Just some quick testing showed a 1.5x speedup end to end on one benchmark that I ran.
/merge
Description
As part of JSON validation, field, value, and string tokens are validated. Currently the code uses a single transform_inclusive_scan; since the transform functor is a heavy operation, it drastically slows down the entire scan.
This PR splits the transform and the scan in validation. The validation runtime went from 200 ms to 20 ms.
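A minimal sketch of that split, assuming hypothetical placeholders (is_valid_token, d_valid, and split_transform_and_scan are illustrative names, not the PR's actual identifiers):

#include <cstddef>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_uvector.hpp>
#include <rmm/exec_policy.hpp>
#include <thrust/functional.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/scan.h>
#include <thrust/transform.h>

// Hypothetical heavy per-token check, standing in for the real validation functor.
struct is_valid_token {
  __device__ bool operator()(std::size_t) const { return true; }
};

void split_transform_and_scan(std::size_t num_tokens, rmm::cuda_stream_view stream)
{
  auto count_it = thrust::make_counting_iterator<std::size_t>(0);

  // Step 1: run the heavy validation once as a plain transform, materializing
  // per-token results instead of recomputing them inside the scan operator.
  rmm::device_uvector<bool> d_valid(num_tokens, stream);
  thrust::transform(rmm::exec_policy(stream),
                    count_it,
                    count_it + num_tokens,
                    d_valid.begin(),
                    is_valid_token{});

  // Step 2: scan the precomputed flags; the scan's binary operator is now trivial.
  thrust::inclusive_scan(rmm::exec_policy(stream),
                         d_valid.begin(),
                         d_valid.end(),
                         d_valid.begin(),
                         thrust::logical_and<bool>{});
}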
Also, a few hardcoded string comparisons are moved to a trie.
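For illustration only, a generic trie sketch of that idea: a token is checked against a small set of literals (such as true/false/null) with one trie walk instead of chained hardcoded comparisons. build_trie and trie_contains are hypothetical helpers, not cuDF's trie utility:

#include <array>
#include <string>
#include <string_view>
#include <vector>

struct trie_node {
  std::array<int, 256> child;  // child index per byte value, -1 if absent
  bool is_end = false;
  trie_node() { child.fill(-1); }
};

// Build a flat trie over the literal strings to match.
std::vector<trie_node> build_trie(std::vector<std::string> const& keys)
{
  std::vector<trie_node> nodes(1);
  for (auto const& key : keys) {
    int cur = 0;
    for (unsigned char c : key) {
      if (nodes[cur].child[c] < 0) {
        nodes[cur].child[c] = static_cast<int>(nodes.size());
        nodes.emplace_back();
      }
      cur = nodes[cur].child[c];
    }
    nodes[cur].is_end = true;
  }
  return nodes;
}

// Walk the trie once per token; returns true only on an exact match.
bool trie_contains(std::vector<trie_node> const& nodes, std::string_view key)
{
  int cur = 0;
  for (unsigned char c : key) {
    cur = nodes[cur].child[c];
    if (cur < 0) return false;
  }
  return nodes[cur].is_end;
}

// Usage: build once, reuse for every token.
// auto t = build_trie({"true", "false", "null"});
// bool ok = trie_contains(t, token_text);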
Checklist