Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 8 additions & 11 deletions onnxruntime/core/providers/cpu/math/det.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,7 @@

#include "core/providers/cpu/math/det.h"
#include "core/util/math_cpuonly.h"
// TODO: fix the warnings
#if defined(_MSC_VER) && !defined(__clang__)
// Chance of arithmetic overflow could be reduced
#pragma warning(disable : 26451)
#endif
#include "core/common/narrow.h"

using namespace onnxruntime::common;

Expand All @@ -33,7 +29,7 @@ Status Det<T>::Compute(OpKernelContext* context) const {
ORT_ENFORCE(X != nullptr);

const auto& X_shape = X->Shape();
int X_num_dims = static_cast<int>(X_shape.NumDimensions());
size_t X_num_dims = X_shape.NumDimensions();

// input validation
if (X_num_dims < 2) { // this is getting capture by shape inference code as well
Expand All @@ -44,10 +40,11 @@ Status Det<T>::Compute(OpKernelContext* context) const {
}

const auto* X_data = X->Data<T>();
int matrix_dim = static_cast<int>(X_shape[X_num_dims - 1]);
int64_t matrix_dim = X_shape[X_num_dims - 1];

auto get_determinant = [matrix_dim](const T* matrix_ptr) -> T {
auto one_eigen_mat = ConstEigenMatrixMapRowMajor<T>(matrix_ptr, matrix_dim, matrix_dim);
auto one_eigen_mat = ConstEigenMatrixMapRowMajor<T>(
matrix_ptr, onnxruntime::narrow<Eigen::Index>(matrix_dim), onnxruntime::narrow<Eigen::Index>(matrix_dim));
return one_eigen_mat.determinant();
};

Expand All @@ -60,15 +57,15 @@ Status Det<T>::Compute(OpKernelContext* context) const {
std::vector<int64_t> output_shape;
output_shape.reserve(X_num_dims - 2);
int64_t batch_size = 1;
for (int i = 0; i < X_num_dims - 2; ++i) {
for (size_t i = 0; i < X_num_dims - 2; ++i) {
batch_size *= X_shape[i];
output_shape.push_back(X_shape[i]);
}

int num_matrix_elems = matrix_dim * matrix_dim;
int64_t num_matrix_elems = matrix_dim * matrix_dim;
auto* Y = context->Output(0, output_shape);
auto* Y_data = Y->MutableData<T>();
for (int b = 0; b < static_cast<int>(batch_size); ++b) { // can be parallelized if need to
for (int64_t b = 0; b < batch_size; ++b) { // can be parallelized if need to
const T* one_matrix = X_data + (b * num_matrix_elems);
*Y_data++ = get_determinant(one_matrix);
}
Expand Down
Loading