diff --git a/src/hsort_impl.h b/src/hsort_impl.h index a6ff32ab35..24fc59add6 100644 --- a/src/hsort_impl.h +++ b/src/hsort_impl.h @@ -23,20 +23,22 @@ static SECP256K1_INLINE size_t secp256k1_heap_child2(size_t i) { return secp256k1_heap_child1(i)+1; } -static SECP256K1_INLINE void secp256k1_heap_swap64(unsigned char *a, size_t i, size_t j, size_t stride) { +static SECP256K1_INLINE void secp256k1_heap_swap64(unsigned char *a, size_t i, size_t j, size_t stride, size_t swap_size) { unsigned char tmp[64]; - VERIFY_CHECK(stride <= 64); - memcpy(tmp, a + i*stride, stride); - memmove(a + i*stride, a + j*stride, stride); - memcpy(a + j*stride, tmp, stride); + VERIFY_CHECK(swap_size <= 64); + memcpy(tmp, a + i*stride, swap_size); + memmove(a + i*stride, a + j*stride, swap_size); + memcpy(a + j*stride, tmp, swap_size); } +/* Swap the elements of a at indices i and j, assuming that the size of each element is stride. */ static SECP256K1_INLINE void secp256k1_heap_swap(unsigned char *a, size_t i, size_t j, size_t stride) { - while (64 < stride) { - secp256k1_heap_swap64(a + (stride - 64), i, j, 64); - stride -= 64; + size_t remaining = stride; + while (64 < remaining) { + secp256k1_heap_swap64(a + (remaining - 64), i, j, stride, 64); + remaining -= 64; } - secp256k1_heap_swap64(a, i, j, stride); + secp256k1_heap_swap64(a, i, j, stride, remaining); } static SECP256K1_INLINE void secp256k1_heap_down(unsigned char *a, size_t i, size_t heap_size, size_t stride, diff --git a/src/tests.c b/src/tests.c index 4d3f22e16a..39407b0a20 100644 --- a/src/tests.c +++ b/src/tests.c @@ -6607,6 +6607,18 @@ static void run_pubkey_comparison(void) { CHECK(secp256k1_ec_pubkey_cmp(CTX, &pk2, &pk1) > 0); } +static void test_heap_swap(void) { + unsigned char a[600]; + unsigned char e[sizeof(a)]; + memset(a, 21, 200); + memset(a + 200, 99, 200); + memset(a + 400, 42, 200); + memset(e, 42, 200); + memset(e + 200, 99, 200); + memset(e + 400, 21, 200); + secp256k1_heap_swap(a, 0, 2, 200); + CHECK(secp256k1_memcmp_var(a, e, sizeof(a)) == 0); +} static void test_hsort_is_sorted(int *ints, size_t n) { size_t i; @@ -6801,6 +6813,7 @@ static void run_pubkey_sort(void) { test_sort_api(); test_sort(); test_sort_vectors(); + test_heap_swap(); }