diff --git a/barretenberg/cpp/src/barretenberg/common/bb_bench.cpp b/barretenberg/cpp/src/barretenberg/common/bb_bench.cpp index 79ef78974b9c..f4eb329b7957 100644 --- a/barretenberg/cpp/src/barretenberg/common/bb_bench.cpp +++ b/barretenberg/cpp/src/barretenberg/common/bb_bench.cpp @@ -37,7 +37,7 @@ std::string format_time(double time_ms) std::ostringstream oss; if (time_ms >= 1000.0) { oss << std::fixed << std::setprecision(2) << (time_ms / 1000.0) << " s"; - } else if (time_ms >= 1.0) { + } else if (time_ms >= 1.0 && time_ms < 1000.0) { oss << std::fixed << std::setprecision(2) << time_ms << " ms"; } else { oss << std::fixed << std::setprecision(1) << (time_ms * 1000.0) << " μs"; @@ -50,42 +50,75 @@ std::string format_time_aligned(double time_ms) { std::ostringstream oss; if (time_ms >= 1000.0) { - std::string time_str = - (std::ostringstream{} << std::fixed << std::setprecision(2) << (time_ms / 1000.0) << "s").str(); - oss << std::left << std::setw(10) << time_str; + std::ostringstream time_oss; + time_oss << std::fixed << std::setprecision(2) << (time_ms / 1000.0) << "s"; + oss << std::left << std::setw(10) << time_oss.str(); } else { - std::string time_str = (std::ostringstream{} << std::fixed << std::setprecision(1) << time_ms << "ms").str(); - oss << std::left << std::setw(10) << time_str; + std::ostringstream time_oss; + time_oss << std::fixed << std::setprecision(1) << time_ms << "ms"; + oss << std::left << std::setw(10) << time_oss.str(); } return oss.str(); } +// Helper to format percentage value +std::string format_percentage_value(double percentage, const char* color) +{ + std::ostringstream oss; + oss << color << " " << std::left << std::fixed << std::setprecision(1) << std::setw(5) << percentage << "%" + << Colors::RESET; + return oss.str(); +} + // Helper to format percentage with color based on percentage value std::string format_percentage(double value, double total, double min_threshold = 0.0) { - if (total <= 0) { - return " "; - } - double percentage = (value / total) * 100.0; - if (percentage < min_threshold) { + double percentage = (total <= 0) ? 0.0 : (value / total) * 100.0; + if (total <= 0 || percentage < min_threshold) { return " "; } // Choose color based on percentage value (like time colors) const char* color = Colors::CYAN; // Default color + return format_percentage_value(percentage, color); +} + +// Helper to format percentage section +std::string format_percentage_section(double time_ms, double parent_time, size_t indent_level) +{ + if (parent_time > 0 && indent_level > 0) { + return format_percentage(time_ms * 1000000.0, parent_time); + } + return " "; +} + +// Helper to format time section +std::string format_time_section(double time_ms) +{ std::ostringstream oss; - oss << color << " " << std::left << std::fixed << std::setprecision(1) << std::setw(5) << percentage << "%" - << Colors::RESET; + oss << " "; + if (time_ms >= 100.0 && time_ms < 1000.0) { + oss << Colors::DIM << format_time_aligned(time_ms) << Colors::RESET; + } else { + oss << format_time_aligned(time_ms); + } + return oss.str(); +} + +// Helper to format call stats +std::string format_call_stats(double time_ms, uint64_t count) +{ + if (!(time_ms >= 100.0 && count > 1)) { + return ""; + } + double avg_ms = time_ms / static_cast(count); + std::ostringstream oss; + oss << Colors::DIM << " (" << format_time(avg_ms) << " x " << count << ")" << Colors::RESET; return oss.str(); } -std::string format_aligned_section(double time_ms, - double parent_time, - uint64_t count, - size_t indent_level, - size_t num_threads = 1, - double mean_ms = 0.0) +std::string format_aligned_section(double time_ms, double parent_time, uint64_t count, size_t indent_level) { std::ostringstream oss; @@ -93,29 +126,13 @@ std::string format_aligned_section(double time_ms, oss << Colors::MAGENTA << "[" << indent_level << "] " << Colors::RESET; // Format percentage FIRST - if (parent_time > 0 && indent_level > 0) { - oss << format_percentage(time_ms * 1000000.0, parent_time); - } else { - oss << " "; // Keep alignment for root entries - } + oss << format_percentage_section(time_ms, parent_time, indent_level); // Format time AFTER percentage with appropriate color (with more spacing) - if (time_ms >= 100.0 && time_ms < 1000.0) { - oss << " " << Colors::DIM << format_time_aligned(time_ms) << Colors::RESET; - } else { - oss << " " << format_time_aligned(time_ms); - } + oss << format_time_section(time_ms); // Format calls/threads info - only show for >= 100ms, make it DIM - if (time_ms >= 100.0) { - if (num_threads > 1) { - oss << Colors::DIM << " (" << std::fixed << std::setprecision(2) << mean_ms << " ms x " << num_threads - << ")" << Colors::RESET; - } else if (count > 1) { - double avg_ms = time_ms / static_cast(count); - oss << Colors::DIM << " (" << format_time(avg_ms) << " x " << count << ")" << Colors::RESET; - } - } + oss << format_call_stats(time_ms, count); return oss.str(); } @@ -129,12 +146,12 @@ struct TimeColor { TimeColor get_time_colors(double time_ms) { if (time_ms >= 1000.0) { - return { Colors::BOLD, Colors::WHITE }; // Bold white for >= 1 second + return { Colors::BOLD, Colors::WHITE }; } if (time_ms >= 100.0) { - return { Colors::YELLOW, Colors::YELLOW }; // Yellow for >= 100ms + return { Colors::YELLOW, Colors::YELLOW }; } - return { Colors::DIM, Colors::DIM }; // Dim for < 100ms + return { Colors::DIM, Colors::DIM }; } // Print separator line @@ -164,12 +181,12 @@ void AggregateEntry::add_thread_time_sample(const TimeAndCount& stats) // Account for aggregate time and count time += stats.time; count += stats.count; + time_max = std::max(stats.time, time_max); // Use Welford's method to be able to track the variance - double time_ms = static_cast(stats.time / stats.count) / 1000000.0; num_threads++; - double delta = time_ms - time_mean; + double delta = static_cast(stats.time) - time_mean; time_mean += delta / static_cast(num_threads); - double delta2 = time_ms - time_mean; + double delta2 = static_cast(stats.time) - time_mean; time_m2 += delta * delta2; } @@ -261,7 +278,7 @@ void GlobalBenchStatsContainer::print_aggregate_counts(std::ostream& os, size_t // Loop for a flattened view uint64_t time = 0; for (auto& [parent_key, entry] : entry_map) { - time += entry.time; + time += entry.time_max; } if (!first) { @@ -307,10 +324,7 @@ void GlobalBenchStatsContainer::print_aggregate_counts_hierarchical(std::ostream // Helper function to print a stat line with tree drawing auto print_entry = [&](const AggregateEntry& entry, size_t indent_level, bool is_last, uint64_t parent_time) { std::string indent(indent_level * 2, ' '); - std::string prefix; - if (indent_level > 0) { - prefix = is_last ? "└─ " : "├─ "; - } + std::string prefix = (indent_level == 0) ? "" : (is_last ? "└─ " : "├─ "); // Use exactly 80 characters for function name without indent const size_t name_width = 80; @@ -319,7 +333,7 @@ void GlobalBenchStatsContainer::print_aggregate_counts_hierarchical(std::ostream display_name = display_name.substr(0, name_width - 3) + "..."; } - double time_ms = static_cast(entry.time) / 1000000.0; + double time_ms = static_cast(entry.time_max) / 1000000.0; auto colors = get_time_colors(time_ms); // Print indent + prefix + name (exactly 80 chars) + time/percentage/calls @@ -330,29 +344,24 @@ void GlobalBenchStatsContainer::print_aggregate_counts_hierarchical(std::ostream os << std::left << std::setw(static_cast(name_width)) << display_name << Colors::RESET; // Print time if available with aligned section including indent level - if (entry.time > 0) { + if (entry.time_max > 0) { if (time_ms < 100.0) { // Minimal format for <100ms: only [level] and percentage, no time display std::ostringstream minimal_oss; minimal_oss << Colors::MAGENTA << "[" << indent_level << "] " << Colors::RESET; - - // Format percentage FIRST - if (parent_time > 0 && indent_level > 0) { - minimal_oss << format_percentage(time_ms * 1000000.0, static_cast(parent_time)); - } else { - minimal_oss << " "; // Keep alignment for root entries - } - - // Add spacing to replace where time would be (with extra spacing to match) - minimal_oss << " " << std::setw(10) << ""; - + minimal_oss << format_percentage_section(time_ms, static_cast(parent_time), indent_level); + minimal_oss << " " << std::setw(10) << ""; // Add spacing to replace where time would be os << " " << colors.time_color << std::setw(40) << std::left << minimal_oss.str() << Colors::RESET; } else { - // Full format for >=100ms - double mean_ms = entry.num_threads > 1 ? entry.time_mean / 1000000.0 : 0.0; - std::string aligned_section = format_aligned_section( - time_ms, static_cast(parent_time), entry.count, indent_level, entry.num_threads, mean_ms); + std::string aligned_section = + format_aligned_section(time_ms, static_cast(parent_time), entry.count, indent_level); os << " " << colors.time_color << std::setw(40) << std::left << aligned_section << Colors::RESET; + if (entry.num_threads > 1) { + double mean_ms = entry.time_mean / 1000000.0; + double stddev_percentage = floor(entry.get_std_dev() * 100 / entry.time_mean); + os << " " << entry.num_threads << " threads " << mean_ms << "ms average " << stddev_percentage + << "% stddev"; + } } } @@ -392,7 +401,7 @@ void GlobalBenchStatsContainer::print_aggregate_counts_hierarchical(std::ostream if (!printed_in_detail.contains(key)) { for (const auto& [child_key, parent_map] : aggregated) { for (const auto& [parent_key, entry] : parent_map) { - if (parent_key == key && entry.time >= 500000) { // 0.5ms in nanoseconds + if (parent_key == key && entry.time_max >= 500000) { // 0.5ms in nanoseconds children.push_back(child_key); break; } @@ -405,27 +414,22 @@ void GlobalBenchStatsContainer::print_aggregate_counts_hierarchical(std::ostream std::ranges::sort(children, [&](OperationKey a, OperationKey b) { uint64_t time_a = 0; uint64_t time_b = 0; - - // Get time for child 'a' when called from THIS parent if (auto it = aggregated.find(a); it != aggregated.end()) { for (const auto& [parent_key, entry] : it->second) { if (parent_key == key) { - time_a = entry.time; + time_a = entry.time_max; break; } } } - - // Get time for child 'b' when called from THIS parent if (auto it = aggregated.find(b); it != aggregated.end()) { for (const auto& [parent_key, entry] : it->second) { if (parent_key == key) { - time_b = entry.time; + time_b = entry.time_max; break; } } } - return time_a > time_b; }); @@ -433,25 +437,21 @@ void GlobalBenchStatsContainer::print_aggregate_counts_hierarchical(std::ostream uint64_t children_total_time = 0; for (const auto& child_key : children) { if (auto it = aggregated.find(child_key); it != aggregated.end()) { - // Sum time for this child across all parent contexts where parent matches current key for (const auto& [parent_key, entry] : it->second) { - if (parent_key == key && entry.time >= 500000) { // 0.5ms in nanoseconds - children_total_time += entry.time; + if (parent_key == key && entry.time_max >= 500000) { // 0.5ms in nanoseconds + children_total_time += entry.time_max; } } } } - - // Check if there's significant unaccounted time (>5%) and we have children - uint64_t parent_total_time = entry_to_print->time; + uint64_t parent_total_time = entry_to_print->time_max; bool should_add_other = false; - uint64_t other_time = 0; if (!children.empty() && parent_total_time > 0 && children_total_time < parent_total_time) { - other_time = parent_total_time - children_total_time; - double other_percentage = - (static_cast(other_time) / static_cast(parent_total_time)) * 100.0; - should_add_other = other_percentage > 5.0 && other_time > 0; + uint64_t unaccounted = parent_total_time - children_total_time; + double percentage = (static_cast(unaccounted) / static_cast(parent_total_time)) * 100.0; + should_add_other = percentage > 5.0 && unaccounted > 0; } + uint64_t other_time = should_add_other ? (parent_total_time - children_total_time) : 0; if (!children.empty() && keys_to_parents[key].size() > 1) { os << std::string(indent_level * 2, ' ') << " ├─ NOTE: Shared children. Can add up to > 100%.\n"; @@ -465,13 +465,12 @@ void GlobalBenchStatsContainer::print_aggregate_counts_hierarchical(std::ostream // Print "(other)" category if significant unaccounted time exists if (should_add_other && keys_to_parents[key].size() <= 1) { - // Create fake AggregateEntry for (other) AggregateEntry other_entry; other_entry.key = "(other)"; other_entry.time = other_time; + other_entry.time_max = other_time; other_entry.count = 1; other_entry.num_threads = 1; - print_entry(other_entry, indent_level + 1, true, parent_total_time); // always last } }; @@ -479,26 +478,24 @@ void GlobalBenchStatsContainer::print_aggregate_counts_hierarchical(std::ostream // Find root entries (those that ONLY have empty parent key and significant time) std::vector roots; for (const auto& [key, parent_map] : aggregated) { - // Check if this operation has an empty parent key entry with significant time - if (auto empty_parent_it = parent_map.find(""); empty_parent_it != parent_map.end()) { - if (empty_parent_it->second.time > 0) { - roots.push_back(key); - } + auto empty_parent_it = parent_map.find(""); + if (empty_parent_it != parent_map.end() && empty_parent_it->second.time > 0) { + roots.push_back(key); } } - // Sort roots by time + // Sort roots by time (descending) std::ranges::sort(roots, [&](OperationKey a, OperationKey b) { uint64_t time_a = 0; uint64_t time_b = 0; - if (auto it = aggregated.find(a); it != aggregated.end()) { - if (auto parent_it = it->second.find(""); parent_it != it->second.end()) { - time_a = parent_it->second.time; + if (auto it_a = aggregated.find(a); it_a != aggregated.end()) { + if (auto parent_it = it_a->second.find(""); parent_it != it_a->second.end()) { + time_a = parent_it->second.time_max; } } - if (auto it = aggregated.find(b); it != aggregated.end()) { - if (auto parent_it = it->second.find(""); parent_it != it->second.end()) { - time_b = parent_it->second.time; + if (auto it_b = aggregated.find(b); it_b != aggregated.end()) { + if (auto parent_it = it_b->second.find(""); parent_it != it_b->second.end()) { + time_b = parent_it->second.time_max; } } return time_a > time_b; @@ -513,41 +510,47 @@ void GlobalBenchStatsContainer::print_aggregate_counts_hierarchical(std::ostream print_separator(os, false); // Calculate totals from root entries - uint64_t total_time = 0; - uint64_t total_calls = 0; - std::set unique_functions; - uint64_t shared_count = 0; - - for (auto& [key, parent_map] : aggregated) { - unique_functions.insert(key); + std::set unique_funcs; + for (const auto& [key, _] : aggregated) { + unique_funcs.insert(key); + } + size_t unique_functions_count = unique_funcs.size(); - // Count as shared if has multiple parents - if (keys_to_parents[key].size() > 1) { + uint64_t shared_count = 0; + for (const auto& [key, parents] : keys_to_parents) { + if (parents.size() > 1) { shared_count++; } - auto& root_entry = parent_map[""]; - total_time += root_entry.time; - // Sum ALL calls - for (auto& entry : parent_map) { - total_calls += entry.second.count; + } + + uint64_t total_time = 0; + for (const auto& [_, parent_map] : aggregated) { + if (auto it = parent_map.find(""); it != parent_map.end()) { + total_time = std::max(total_time, it->second.time_max); + } + } + + uint64_t total_calls = 0; + for (const auto& [_, parent_map] : aggregated) { + for (const auto& [__, entry] : parent_map) { + total_calls += entry.count; } } double total_time_ms = static_cast(total_time) / 1000000.0; - os << " " << Colors::BOLD << "Total: " << Colors::RESET << Colors::MAGENTA << unique_functions.size() + os << " " << Colors::BOLD << "Total: " << Colors::RESET << Colors::MAGENTA << unique_functions_count << " functions" << Colors::RESET; if (shared_count > 0) { os << " (" << Colors::RED << shared_count << " shared" << Colors::RESET << ")"; } - os << ", " << Colors::GREEN << total_calls << " measurements" << Colors::RESET << ", "; - + os << ", " << Colors::GREEN << total_calls << " measurements" << Colors::RESET << ", " << Colors::YELLOW; if (total_time_ms >= 1000.0) { - os << Colors::YELLOW << std::fixed << std::setprecision(2) << (total_time_ms / 1000.0) << " seconds" - << Colors::RESET; + os << std::fixed << std::setprecision(2) << (total_time_ms / 1000.0) << " seconds"; } else { - os << Colors::YELLOW << std::fixed << std::setprecision(2) << total_time_ms << " ms" << Colors::RESET; + os << std::fixed << std::setprecision(2) << total_time_ms << " ms"; } + os << Colors::RESET; os << "\n"; print_separator(os, true); diff --git a/barretenberg/cpp/src/barretenberg/common/bb_bench.hpp b/barretenberg/cpp/src/barretenberg/common/bb_bench.hpp index c4b71f383166..a202c28a4880 100644 --- a/barretenberg/cpp/src/barretenberg/common/bb_bench.hpp +++ b/barretenberg/cpp/src/barretenberg/common/bb_bench.hpp @@ -52,6 +52,7 @@ struct AggregateEntry { std::size_t count = 0; size_t num_threads = 0; double time_mean = 0; + std::size_t time_max = 0; double time_stddev = 0; // Welford's algorithm state diff --git a/barretenberg/cpp/src/barretenberg/common/parallel_for_mutex_pool.cpp b/barretenberg/cpp/src/barretenberg/common/parallel_for_mutex_pool.cpp index c923e3e7a62f..b60b2f732bce 100644 --- a/barretenberg/cpp/src/barretenberg/common/parallel_for_mutex_pool.cpp +++ b/barretenberg/cpp/src/barretenberg/common/parallel_for_mutex_pool.cpp @@ -41,6 +41,7 @@ class ThreadPool { do_iterations(); { + BB_BENCH_NAME("spinning main thread"); std::unique_lock lock(tasks_mutex); complete_condition_.wait(lock, [this] { return complete_ == num_iterations_; }); } @@ -71,6 +72,7 @@ class ThreadPool { } iteration = iteration_++; } + BB_BENCH_NAME("do_iterations()"); task_(iteration); { std::unique_lock lock(tasks_mutex); diff --git a/barretenberg/cpp/src/barretenberg/ecc/groups/element_impl.hpp b/barretenberg/cpp/src/barretenberg/ecc/groups/element_impl.hpp index bbb99d27e840..36a1e15ea90c 100644 --- a/barretenberg/cpp/src/barretenberg/ecc/groups/element_impl.hpp +++ b/barretenberg/cpp/src/barretenberg/ecc/groups/element_impl.hpp @@ -795,7 +795,7 @@ std::vector> element::batch_mul_with_endomo const std::span>& points, const Fr& scalar) noexcept { BB_BENCH(); - typedef affine_element affine_element; + using affine_element = affine_element; const size_t num_points = points.size(); // Space for temporary values @@ -833,20 +833,6 @@ std::vector> element::batch_mul_with_endomo } }; - /** - * @brief Perform batch affine addition in parallel - * - */ - const auto batch_affine_add_internal = - [num_points, &scratch_space, &batch_affine_add_chunked](const affine_element* lhs, affine_element* rhs) { - parallel_for_heuristic( - num_points, - [&](size_t start, size_t end, BB_UNUSED size_t chunk_index) { - batch_affine_add_chunked(lhs + start, rhs + start, end - start, &scratch_space[0] + start); - }, - thread_heuristics::FF_ADDITION_COST * 6 + thread_heuristics::FF_MULTIPLICATION_COST * 6); - }; - /** * @brief Perform point doubling lhs[i]=lhs[i]+lhs[i] with batch inversion * @@ -878,18 +864,6 @@ std::vector> element::batch_mul_with_endomo lhs[i].y = personal_scratch_space[i] * (temp - lhs[i].x) - lhs[i].y; } }; - /** - * @brief Perform point doubling in parallel - * - */ - const auto batch_affine_double = [num_points, &scratch_space, &batch_affine_double_chunked](affine_element* lhs) { - parallel_for_heuristic( - num_points, - [&](size_t start, size_t end, BB_UNUSED size_t chunk_index) { - batch_affine_double_chunked(lhs + start, end - start, &scratch_space[0] + start); - }, - thread_heuristics::FF_ADDITION_COST * 7 + thread_heuristics::FF_MULTIPLICATION_COST * 6); - }; // We compute the resulting point through WNAF by evaluating (the (\sum_i (16ⁱ⋅ // (a_i ∈ {-15,-13,-11,-9,-7,-5,-3,-1,1,3,5,7,9,11,13,15}))) - skew), where skew is 0 or 1. The result of the sum is @@ -916,50 +890,58 @@ std::vector> element::batch_mul_with_endomo constexpr size_t LOOKUP_SIZE = 8; constexpr size_t NUM_ROUNDS = 32; + + detail::EndoScalars endo_scalars = Fr::split_into_endomorphism_scalars(converted_scalar); + detail::EndomorphismWnaf wnaf{ endo_scalars }; + + std::vector work_elements(num_points); std::array, LOOKUP_SIZE> lookup_table; for (auto& table : lookup_table) { table.resize(num_points); } - // Initialize first etnries in lookup table std::vector temp_point_vector(num_points); - parallel_for_heuristic( - num_points, - [&](size_t i) { - // If the point is at infinity we fix-up the result later - // To avoid 'trying to invert zero in the field' we set the point to 'one' here - temp_point_vector[i] = points[i].is_point_at_infinity() ? affine_element::one() : points[i]; - lookup_table[0][i] = points[i].is_point_at_infinity() ? affine_element::one() : points[i]; - }, - thread_heuristics::FF_COPY_COST * 2); - - // Construct lookup table - batch_affine_double(&temp_point_vector[0]); - for (size_t j = 1; j < LOOKUP_SIZE; ++j) { - parallel_for_heuristic( - num_points, - [&](size_t i) { lookup_table[j][i] = lookup_table[j - 1][i]; }, - thread_heuristics::FF_COPY_COST); - batch_affine_add_internal(&temp_point_vector[0], &lookup_table[j][0]); - } - detail::EndoScalars endo_scalars = Fr::split_into_endomorphism_scalars(converted_scalar); - detail::EndomorphismWnaf wnaf{ endo_scalars }; + auto execute_range = [&](size_t start, size_t end) { + // Perform batch affine addition in parallel + const auto add_chunked = [&](const affine_element* lhs, affine_element* rhs) { + batch_affine_add_chunked(&lhs[start], &rhs[start], end - start, &scratch_space[start]); + }; - std::vector work_elements(num_points); + // Perform point doubling in parallel + const auto double_chunked = [&](affine_element* lhs) { + batch_affine_double_chunked(&lhs[start], end - start, &scratch_space[start]); + }; - constexpr Fq beta = Fq::cube_root_of_unity(); - uint64_t wnaf_entry = 0; - uint64_t index = 0; - bool sign = 0; - // Prepare elements for the first batch addition - for (size_t j = 0; j < 2; ++j) { - wnaf_entry = wnaf.table[j]; - index = wnaf_entry & 0x0fffffffU; - sign = static_cast((wnaf_entry >> 31) & 1); - const bool is_odd = ((j & 1) == 1); - parallel_for_heuristic( - num_points, - [&](size_t i) { + // Initialize first entries in lookup table + for (size_t i = start; i < end; ++i) { + if (points[i].is_point_at_infinity()) { + temp_point_vector[i] = affine_element::one(); + lookup_table[0][i] = affine_element::one(); + } else { + temp_point_vector[i] = points[i]; + lookup_table[0][i] = points[i]; + } + } + // Costruct lookup table + double_chunked(&temp_point_vector[0]); + for (size_t j = 1; j < LOOKUP_SIZE; ++j) { + for (size_t i = start; i < end; ++i) { + lookup_table[j][i] = lookup_table[j - 1][i]; + } + add_chunked(&temp_point_vector[0], &lookup_table[j][0]); + } + + constexpr Fq beta = Fq::cube_root_of_unity(); + uint64_t wnaf_entry = 0; + uint64_t index = 0; + bool sign = 0; + // Prepare elements for the first batch addition + for (size_t j = 0; j < 2; ++j) { + wnaf_entry = wnaf.table[j]; + index = wnaf_entry & 0x0fffffffU; + sign = static_cast((wnaf_entry >> 31) & 1); + const bool is_odd = ((j & 1) == 1); + for (size_t i = start; i < end; ++i) { auto to_add = lookup_table[static_cast(index)][i]; to_add.y.self_conditional_negate(sign ^ is_odd); if (is_odd) { @@ -970,64 +952,51 @@ std::vector> element::batch_mul_with_endomo } else { temp_point_vector[i] = to_add; } - }, - (is_odd ? thread_heuristics::FF_MULTIPLICATION_COST : 0) + thread_heuristics::FF_COPY_COST + - thread_heuristics::FF_ADDITION_COST); - } - // First cycle of addition - batch_affine_add_internal(&temp_point_vector[0], &work_elements[0]); - // Run through SM logic in wnaf form (excluding the skew) - for (size_t j = 2; j < NUM_ROUNDS * 2; ++j) { - wnaf_entry = wnaf.table[j]; - index = wnaf_entry & 0x0fffffffU; - sign = static_cast((wnaf_entry >> 31) & 1); - const bool is_odd = ((j & 1) == 1); - if (!is_odd) { - for (size_t k = 0; k < 4; ++k) { - batch_affine_double(&work_elements[0]); } } - parallel_for_heuristic( - num_points, - [&](size_t i) { + add_chunked(&temp_point_vector[0], &work_elements[0]); + // Run through SM logic in wnaf form (excluding the skew) + for (size_t j = 2; j < NUM_ROUNDS * 2; ++j) { + wnaf_entry = wnaf.table[j]; + index = wnaf_entry & 0x0fffffffU; + sign = static_cast((wnaf_entry >> 31) & 1); + const bool is_odd = ((j & 1) == 1); + if (!is_odd) { + for (size_t k = 0; k < 4; ++k) { + double_chunked(&work_elements[0]); + } + } + for (size_t i = start; i < end; ++i) { auto to_add = lookup_table[static_cast(index)][i]; to_add.y.self_conditional_negate(sign ^ is_odd); if (is_odd) { to_add.x *= beta; } temp_point_vector[i] = to_add; - }, - (is_odd ? thread_heuristics::FF_MULTIPLICATION_COST : 0) + thread_heuristics::FF_COPY_COST + - thread_heuristics::FF_ADDITION_COST); - batch_affine_add_internal(&temp_point_vector[0], &work_elements[0]); - } - - // Apply skew for the first endo scalar - if (wnaf.skew) { - parallel_for_heuristic( - num_points, - [&](size_t i) { temp_point_vector[i] = -lookup_table[0][i]; }, - thread_heuristics::FF_ADDITION_COST + thread_heuristics::FF_COPY_COST); - batch_affine_add_internal(&temp_point_vector[0], &work_elements[0]); - } - // Apply skew for the second endo scalar - if (wnaf.endo_skew) { - parallel_for_heuristic( - num_points, - [&](size_t i) { + } + add_chunked(&temp_point_vector[0], &work_elements[0]); + } + // Apply skew for the first endo scalar + if (wnaf.skew) { + for (size_t i = start; i < end; ++i) { + temp_point_vector[i] = -lookup_table[0][i]; + } + add_chunked(&temp_point_vector[0], &work_elements[0]); + } + // Apply skew for the second endo scalar + if (wnaf.endo_skew) { + for (size_t i = start; i < end; ++i) { temp_point_vector[i] = lookup_table[0][i]; temp_point_vector[i].x *= beta; - }, - thread_heuristics::FF_MULTIPLICATION_COST + thread_heuristics::FF_COPY_COST); - batch_affine_add_internal(&temp_point_vector[0], &work_elements[0]); - } - // handle points at infinity explicitly - parallel_for_heuristic( - num_points, - [&](size_t i) { + } + add_chunked(&temp_point_vector[0], &work_elements[0]); + } + // handle points at infinity explicitly + for (size_t i = start; i < end; ++i) { work_elements[i] = points[i].is_point_at_infinity() ? work_elements[i].set_infinity() : work_elements[i]; - }, - thread_heuristics::FF_COPY_COST); + } + }; + parallel_for_range(num_points, execute_range); return work_elements; }