
Performance optimization of JSON validation #16996

Merged

Conversation

@karthikeyann (Contributor) commented Oct 3, 2024

Description

As part of JSON validation, field, value, and string tokens are validated. Currently the code uses a single transform_inclusive_scan, and because the transform functor is a heavy operation, it slows down the entire scan drastically.
This PR splits the transform and the scan in validation, reducing the validation runtime from 200 ms to 20 ms.

Also, a few hardcoded string comparisons are moved to a trie.
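
For illustration, here is a minimal sketch of the kind of split described above, assuming hypothetical names (token_state, heavy_validate_op, scan_op, tokens, out, num_tokens, stream); it is not the actual cudf code. The heavy validation functor runs exactly once in a thrust::transform into a temporary buffer, and only the cheap combine operator runs inside the subsequent scan, instead of fusing both into one thrust::transform_inclusive_scan.

#include <rmm/device_uvector.hpp>
#include <rmm/exec_policy.hpp>
#include <thrust/scan.h>
#include <thrust/transform.h>

// Hypothetical sketch -- not the cudf implementation.
// Before: the heavy functor is fused into the scan:
//   thrust::transform_inclusive_scan(rmm::exec_policy(stream),
//                                    tokens.begin(), tokens.end(), out.begin(),
//                                    heavy_validate_op, scan_op);
// After: evaluate the heavy functor once, then scan the lightweight results.
rmm::device_uvector<token_state> temp(num_tokens, stream);
thrust::transform(rmm::exec_policy(stream),
                  tokens.begin(), tokens.end(),
                  temp.begin(),
                  heavy_validate_op);  // expensive per-token work, evaluated once
thrust::inclusive_scan(rmm::exec_policy(stream),
                       temp.begin(), temp.end(),
                       out.begin(),
                       scan_op);       // cheap associative combine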

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.

@karthikeyann karthikeyann added the labels 3 - Ready for Review, libcudf, cuIO, Performance, Spark, improvement, and non-breaking on Oct 3, 2024
@karthikeyann karthikeyann requested a review from a team as a code owner October 3, 2024 22:55
@davidwendt (Contributor) commented:

This kind of change usually improves compile time as well.
I looked up process_tokens.cu.o: its compile time was about 3 minutes and is now 39 seconds with this PR.
Nice work.

@GregoryKimball GregoryKimball requested a review from shrshi October 4, 2024 17:05
@karthikeyann (Contributor, Author) commented Oct 4, 2024

The warp-parallel validation algorithm is slower than this approach. Its transform takes slightly less time, but the subsequent copy_if and warp-parallel validation take more time than the reduction in transform time.

Benchmark used for profiling: https://github.com/karthikeyann/cudf/blob/enh-profile_memusage_json/wm_benchmark.py
Profile of this PR:
[profiler screenshot]
Profile of warp-parallel validation (copy_if large_string node ids, then warp-parallel transform):
[profiler screenshot]

Code for warp-parallel validation:

constexpr auto SINGLE_THREAD_THRESHOLD = 128;

// Unicode code point escape sequence
static constexpr char UNICODE_SEQ = 0x7F;

// Invalid escape sequence
static constexpr char NON_ESCAPE_CHAR = 0x7E;

// Unicode code point escape sequence comprises four hex characters
static constexpr size_type UNICODE_HEX_DIGIT_COUNT = 4;

/**
 * @brief Returns the character to output for a given escaped character that's following a
 * backslash.
 *
 * @param escaped_char The character following the backslash.
 * @return The character to output for a given character that's following a backslash
 */
__device__ __forceinline__ char get_escape_char(char escaped_char)
{
  switch (escaped_char) {
    case '"': return '"';
    case '\\': return '\\';
    case '/': return '/';
    case 'b': return '\b';
    case 'f': return '\f';
    case 'n': return '\n';
    case 'r': return '\r';
    case 't': return '\t';
    case 'u': return UNICODE_SEQ;
    default: return NON_ESCAPE_CHAR;
  }
}

__device__ __forceinline__ bool is_hex(char ch) {
  return (ch >= '0' && ch <= '9') || (ch >= 'A' && ch <= 'F') || (ch >= 'a' && ch <= 'f');
}

// Add extra condition in predicate before !validate_strings
// if ((token_indices[i] - token_indices[i - 1]) > SINGLE_THREAD_THRESHOLD) return false; // large strings are validated separately

{
    nvtx3::scoped_range_in<libcudf_domain> ctx{"longstr_transform"};
    // Algorithm:
    // for(each warp: 32 chars)
    // if  !allow_unquoted_control_chars and char < 32, then error
    // check if escaping_backslash
    // if prev is escaping_backslash and curr is not a valid escape char, then error
    // if prev is escaping_backslash and curr is 'u' and end-curr < 4, then error
    // if curr is backslash and end, then error
  auto validate_warp_strings = cuda::proclaim_return_type<bool>(
    [data = d_input.data(),
     allow_unquoted_control_chars =
       options.is_allowed_unquoted_control_chars()] __device__(SymbolOffsetT start,
                                                               SymbolOffsetT end) -> bool {
      auto in_begin   = data + start;
      auto in_end     = data + end;
      // auto warp_id = threadIdx.x / cudf::detail::warp_size;
      size_type lane  = threadIdx.x % cudf::detail::warp_size;
      bool init_state = false;
      // using bitfield = bitfield_warp<num_warps>;
      // bitfield is_slash;
      // is_slash.reset(warp_id);
      // __sync_warp();
      for (thread_index_type char_index = lane;
           char_index < cudf::util::round_up_safe(end - start,
                                                  static_cast<unsigned>(cudf::detail::warp_size));
           char_index += cudf::detail::warp_size) {

          namespace cg = cooperative_groups;
          auto thread_block = cg::this_thread_block();
          auto tile = cg::tiled_partition<cudf::detail::warp_size>(thread_block);

          bool const is_within_bounds = char_index < (in_end - in_begin);
          auto const c                = is_within_bounds ? in_begin[char_index] : '\0';
          if (!allow_unquoted_control_chars) {
            if (tile.any(is_within_bounds && static_cast<int>(c) >= 0 && static_cast<int>(c) < 32)) {
              return false;
            }
          }
          bool is_escaping_backslash{false};
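          // state0/state1 give the next "previous char is an escaping backslash" flag
          // for an incoming flag of false/true respectively; composing these transition
          // tables is associative, which is what lets the warp-wide inclusive scan
          // propagate the escape state across all 32 characters.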
          struct state_table {
            // using bit fields instead of state[2] TODO: try if it is faster
            bool state0 : 1;
            bool state1 : 1;
            [[nodiscard]] bool inline __device__ get(bool init_state) const { return init_state ? state1 : state0; }
          };
          state_table curr{is_within_bounds && c == '\\', false};  // state transition vector.
          auto composite_op = [](state_table op1, state_table op2) {
            // equivalent of state_table{op2.state[op1.state[0]], op2.state[op1.state[1]]};
            return state_table{op1.state0 ? op2.state1 : op2.state0,
                              op1.state1 ? op2.state1 : op2.state0};
          };
          state_table scanned = cg::inclusive_scan(tile, curr, composite_op);
          // SlashScan(temp_slash[warp_id]).InclusiveScan(curr, scanned, composite_op);
          is_escaping_backslash = scanned.get(init_state);
          // init_state            = __shfl_sync(~0u, is_escaping_backslash, BLOCK_SIZE - 1);
          init_state = tile.shfl(is_escaping_backslash, cudf::detail::warp_size - 1);
          // __syncwarp();
          // is_slash.shift(warp_id);
          // is_slash.set_bits(warp_id, is_escaping_backslash);
          auto const next_escaped_char = [&]() {
            // Avoid shadowing the outer is_within_bounds; check the following character.
            bool const next_within_bounds = char_index + 1 < (end - start);
            auto const next_c             = next_within_bounds ? in_begin[char_index + 1] : '\0';
            return get_escape_char(next_c);
          }();
          // String with parsing errors are made as null
          bool error = false;
          if (is_within_bounds) {
            // curr=='\' and end, or prev=='\' and curr=='u' and end-curr < UNICODE_HEX_DIGIT_COUNT
            // or prev=='\' and curr=='u' and end-curr >= UNICODE_HEX_DIGIT_COUNT and any non-hex
            error |= (is_escaping_backslash /*c == '\\'*/ && char_index == (in_end - in_begin) - 1);
            error |= (is_escaping_backslash && next_escaped_char == NON_ESCAPE_CHAR);
            error |= (is_escaping_backslash && next_escaped_char == UNICODE_SEQ &&
                      ((in_begin + char_index + 1 + UNICODE_HEX_DIGIT_COUNT >= in_end) ||
                       !is_hex(in_begin[char_index + 2]) || !is_hex(in_begin[char_index + 3]) ||
                       !is_hex(in_begin[char_index + 4]) || !is_hex(in_begin[char_index + 5])));
          }
          // error = __any_sync(~0u, error);
          error = tile.any(error);
          if (error) { return false; }
        }
        return true;
    });

    // reduce > SINGLE_THREAD_THRESHOLD size to get count, allocate, and validate using a warp
    auto is_large_string = cuda::proclaim_return_type<bool>(
        [tokens        = tokens.begin(),
         token_indices = token_indices.begin()] __device__(auto i) -> bool {
          if (tokens[i] == token_t::FieldNameEnd || tokens[i] == token_t::StringEnd)
            return (token_indices[i] - token_indices[i - 1]) > SINGLE_THREAD_THRESHOLD;
          return false;
        });
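    // 'predicate' (captured from outside this snippet) is assumed to return true for
    // tokens that fail the single-thread validation, so a single pass can both mark
    // those as invalid and count the large strings that need warp-level validation.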
    auto is_large_string2 = cuda::proclaim_return_type<size_type>(
    [is_large_string, predicate, d_invalid = d_invalid.begin(),
      token_indices = token_indices.begin()] __device__(auto i) -> size_type {
      if (predicate(i)) d_invalid[i] = true;
      return is_large_string(i);
    });
    auto num_large_strings = [&]() {
      auto it = thrust::make_transform_iterator(count_it, is_large_string2);
      nvtx3::scoped_range_in<libcudf_domain> ctx{"count_if"};
      return thrust::reduce(rmm::exec_policy(stream), it, it + num_tokens);
    }();
    rmm::device_uvector<size_type> large_string_indices(num_large_strings, stream);
    {
      nvtx3::scoped_range_in<libcudf_domain> ctx{"copy_if"};
      thrust::copy_if(rmm::exec_policy(stream),
                      count_it,
                      count_it + num_tokens,
                      large_string_indices.begin(),
                      is_large_string);
    }
    // Validate large strings using warp-level parallelism.
    // for_each works here because the cooperative-groups warp scan needs no shared memory.
    auto count_it2 = thrust::make_counting_iterator<size_t>(0);
    {
      nvtx3::scoped_range_in<libcudf_domain> ctx{"for_each"};
      thrust::for_each(rmm::exec_policy(stream),
                       count_it2,
                       count_it2 + num_large_strings * cudf::detail::warp_size,
                       [tokens               = tokens.begin(),
                        token_indices        = token_indices.begin(),
                        validate_warp_strings,
                        d_invalid            = d_invalid.begin(),
                        large_string_indices = large_string_indices.begin()] __device__(auto i) {
                         // One warp (32 threads) is assigned per large string; the warp-scan
                         // validation happens inside validate_warp_strings.
                         auto idx       = i / cudf::detail::warp_size;
                         auto lane_id   = i % cudf::detail::warp_size;
                         auto token_idx = large_string_indices[idx];
                         auto is_valid  = validate_warp_strings(token_indices[token_idx - 1],
                                                                token_indices[token_idx]);
                         if (not is_valid and lane_id == 0) { d_invalid[token_idx] = true; }
                       });
    }
  }
  

@vuule (Contributor) left a comment:


Amazing optimization! Got a few small questions and suggestions.

@karthikeyann karthikeyann requested a review from vuule October 7, 2024 19:44
@karthikeyann karthikeyann requested a review from ttnghia October 7, 2024 20:08
@revans2 (Contributor) left a comment:


Tests pass and this is a big performance boost. Just some quick testing showed a 1.5x speedup end to end on one benchmark that I ran.

@karthikeyann (Contributor, Author) commented:

/merge

@rapids-bot rapids-bot bot merged commit 553d8ec into rapidsai:branch-24.12 Oct 8, 2024
102 checks passed