Skip to content

Commit

Permalink
Bug fixes
Browse files Browse the repository at this point in the history
Summary:
This diff fixes two bugs that I found when creating a custom op and comparing results to PyTorch Python implementations (next diff).

1) There is a segfault that occurred when n % 8 != 0 because the ukernel was storing out of bounds.  There was an existing test for this case, but it passed because the output shape in the test was mistakenly too big, so no out-of-bounds memory was written to in the test (the output had shape m x k instead of the correct shape m x n).  This diff fixes the out-of-bounds writes and corrects the existing test.

2) The find_min_and_max function was incorrect.  This corrects the function and adds tests for the reduction functions (find_min_and_max and compute_sum).  (The find_min_and_max function is only used for dynamic quantization; there are existing tests for the quantization, but they passed because the existing find_min_and_max happened to return correct results in the tested case.)

Reviewed By: digantdesai

Differential Revision: D60773448
  • Loading branch information
metascroy authored and facebook-github-bot committed Aug 20, 2024
1 parent 1909171 commit d29a33e
Show file tree
Hide file tree
Showing 10 changed files with 246 additions and 148 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,21 @@ void kernel_impl(
if constexpr (has_clamp) {
res = clamp(res, clamp_min, clamp_max);
}
vst1q_f32(output + m_idx * output_m_stride + n_idx, res);

// Store result
int remaining = n - n_idx;
float* store_loc = output + m_idx * output_m_stride + n_idx;
if (remaining >= 4) {
vst1q_f32(store_loc, res);
} else if (remaining >= 3) {
vst1_f32(store_loc, vget_low_f32(res));
*(store_loc + 2) = res[2];
} else if (remaining >= 2) {
vst1_f32(store_loc, vget_low_f32(res));
} else {
*(store_loc) = res[0];
}

} // n_idx
activation_data_byte_ptr += (activation_ptr - activation_data_byte_ptr);
} // m_idx
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -290,8 +290,34 @@ void kernel_impl(
res_0123 = vec_clamp(res_0123, vec_min, vec_max);
res_4567 = vec_clamp(res_4567, vec_min, vec_max);
}
vst1q_f32(output + m_idx * output_m_stride + n_idx, res_0123);
vst1q_f32(output + m_idx * output_m_stride + n_idx + 4, res_4567);

// Store result
int remaining = n - n_idx;
float* store_loc = output + m_idx * output_m_stride + n_idx;
if (remaining >= 8) {
vst1q_f32(store_loc, res_0123);
vst1q_f32(store_loc + 4, res_4567);
} else if (remaining >= 7) {
vst1q_f32(store_loc, res_0123);
vst1_f32(store_loc + 4, vget_low_f32(res_4567));
*(store_loc + 6) = res_4567[2];
} else if (remaining >= 6) {
vst1q_f32(store_loc, res_0123);
vst1_f32(store_loc + 4, vget_low_f32(res_4567));
} else if (remaining >= 5) {
vst1q_f32(store_loc, res_0123);
*(store_loc + 4) = res_4567[0];
} else if (remaining >= 4) {
vst1q_f32(store_loc, res_0123);
} else if (remaining >= 3) {
vst1_f32(store_loc, vget_low_f32(res_0123));
*(store_loc + 2) = res_0123[2];
} else if (remaining >= 2) {
vst1_f32(store_loc, vget_low_f32(res_0123));
} else {
*store_loc = res_0123[0];
}

} // n_idx
activation_data_byte_ptr += (activation_ptr - activation_data_byte_ptr);
} // m_idx
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

#include <torchao/experimental/kernels/cpu/aarch64/reduction/reduction.h>
#include <cassert>

int32_t torchao::kernels::cpu::aarch64::reduction::compute_sum(
const int8_t* vals,
int size) {
assert(size >= 1);

int32_t res = 0;
int i = 0;

#pragma unroll(4)
for (; i < size; i += 16) {
for (; i + 15 < size; i += 16) {
int8x16_t vec_vals = vld1q_s8(vals + i);
res += (int)(vaddlvq_s8(vec_vals));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,23 +1,33 @@
// (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

#include <torchao/experimental/kernels/cpu/aarch64/reduction/reduction.h>
#include <cassert>

void torchao::kernels::cpu::aarch64::reduction::find_min_and_max(
float32_t& min,
float32_t& max,
const float32_t* vals,
int size) {
float32x4_t mins = vdupq_n_f32(0.0);
float32x4_t maxes = vdupq_n_f32(0.0);
assert(size > 0);

// Needed in case size < 4 so we don't compare to
// uninitialized min/max values
min = vals[0];
max = min;

int i = 0;
for (; i < size; i += 8) {
float32x4_t v1 = vld1q_f32(vals + i);
float32x4_t v2 = vld1q_f32(vals + i + 4);
mins = vminq_f32(v1, v2);
maxes = vmaxq_f32(v1, v2);
if (i + 3 < size) {
float32x4_t mins = vld1q_f32(vals + i);
float32x4_t maxes = mins;
i += 4;
for (; i + 3 < size; i += 4) {
float32x4_t v = vld1q_f32(vals + i);
mins = vminq_f32(mins, v);
maxes = vmaxq_f32(maxes, v);
}
min = vminvq_f32(mins);
max = vmaxvq_f32(maxes);
}
min = vminvq_f32(mins);
max = vmaxvq_f32(maxes);

// Remainder
while (i < size) {
Expand Down
9 changes: 9 additions & 0 deletions torchao/experimental/kernels/cpu/aarch64/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,14 @@ target_link_libraries(
dep
)

add_executable(test_reduction test_reduction.cpp)
target_link_libraries(
test_reduction
PRIVATE
GTest::gtest_main
dep
)

add_executable(test_bitpacking test_bitpacking.cpp)
target_link_libraries(
test_bitpacking
Expand All @@ -61,6 +69,7 @@ target_link_libraries(

include(GoogleTest)
gtest_discover_tests(test_quantization)
gtest_discover_tests(test_reduction)
gtest_discover_tests(test_bitpacking)
gtest_discover_tests(test_linear)
gtest_discover_tests(test_valpacking)
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} -S ${TORCHAO_LIBRARIES}/torchao/e
cmake --build ${CMAKE_OUT}

# Run
${CMAKE_OUT}/test_quantization
${CMAKE_OUT}/test_bitpacking
${CMAKE_OUT}/test_linear
${CMAKE_OUT}/test_valpacking
${CMAKE_OUT}/test_quantization
${CMAKE_OUT}/test_reduction
${CMAKE_OUT}/test_bitpacking
${CMAKE_OUT}/test_linear
${CMAKE_OUT}/test_valpacking
Loading

0 comments on commit d29a33e

Please sign in to comment.