Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[MXNET-244] Work around likely compiler bug on nested inlines and tem…
Browse files Browse the repository at this point in the history
…porary acces… (#13535)

* Work around likely compiler bug on nested inlines and temporary access to stream

* Don't compile khatri_rao tests if we don't have LAPACK

* Address CR comment
  • Loading branch information
larroy authored and szha committed Jan 7, 2019
1 parent c63ef9a commit 6dae0bf
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 24 deletions.
72 changes: 72 additions & 0 deletions src/operator/c_lapack_api.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include "c_lapack_api.h"

#if (MSHADOW_USE_MKL && MXNET_USE_LAPACK)
#elif MXNET_USE_LAPACK
#else
// use pragma message instead of warning
#pragma message("Warning: lapack usage not enabled, linalg-operators will not be available." \
" Ensure that lapack library is installed and build with USE_LAPACK=1 to get lapack" \
" functionalities.")

// Define compilable stubs.
#define MXNET_LAPACK_CWRAPPER1(func, dtype) \
int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype* a, int lda) { \
LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \
return 1; \
}

#define MXNET_LAPACK_CWRAPPER2(func, dtype) \
int MXNET_LAPACK_##func(int matrix_layout, int m, int n, dtype* a, \
int lda, dtype* tau, dtype* work, int lwork) { \
LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \
return 1; \
}

#define MXNET_LAPACK_CWRAPPER3(func, dtype) \
int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype *a, \
int lda, dtype *w, dtype *work, int lwork, \
int *iwork, int liwork) { \
LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \
return 1; \
}

#define MXNET_LAPACK_UNAVAILABLE(func) \
int mxnet_lapack_##func(...) { \
LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \
return 1; \
}
MXNET_LAPACK_CWRAPPER1(spotrf, float)
MXNET_LAPACK_CWRAPPER1(dpotrf, double)
MXNET_LAPACK_CWRAPPER1(spotri, float)
MXNET_LAPACK_CWRAPPER1(dpotri, double)

MXNET_LAPACK_UNAVAILABLE(sposv)
MXNET_LAPACK_UNAVAILABLE(dposv)

MXNET_LAPACK_CWRAPPER2(sgelqf, float)
MXNET_LAPACK_CWRAPPER2(dgelqf, double)
MXNET_LAPACK_CWRAPPER2(sorglq, float)
MXNET_LAPACK_CWRAPPER2(dorglq, double)

MXNET_LAPACK_CWRAPPER3(ssyevd, float)
MXNET_LAPACK_CWRAPPER3(dsyevd, double)
#endif // MSHADOW_USE_MKL == 0
35 changes: 11 additions & 24 deletions src/operator/c_lapack_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -324,42 +324,26 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {

#else

// use pragma message instead of warning
#pragma message("Warning: lapack usage not enabled, linalg-operators will not be available." \
" Ensure that lapack library is installed and build with USE_LAPACK=1 to get lapack" \
" functionalities.")


#define MXNET_LAPACK_ROW_MAJOR 101
#define MXNET_LAPACK_COL_MAJOR 102

// Define compilable stubs.
#define MXNET_LAPACK_CWRAPPER1(func, dtype) \
inline int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype* a, int lda) { \
LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \
return 1; \
}
int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype* a, int lda);

#define MXNET_LAPACK_CWRAPPER2(func, dtype) \
inline int MXNET_LAPACK_##func(int matrix_layout, int m, int n, dtype* a, \
int lda, dtype* tau, dtype* work, int lwork) { \
LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \
return 1; \
}
int MXNET_LAPACK_##func(int matrix_layout, int m, int n, dtype* a, \
int lda, dtype* tau, dtype* work, int lwork);

#define MXNET_LAPACK_CWRAPPER3(func, dtype) \
inline int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype *a, \
int MXNET_LAPACK_##func(int matrix_layout, char uplo, int n, dtype *a, \
int lda, dtype *w, dtype *work, int lwork, \
int *iwork, int liwork) { \
LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \
return 1; \
}
int *iwork, int liwork);

#define MXNET_LAPACK_UNAVAILABLE(func) \
inline int mxnet_lapack_##func(...) { \
LOG(FATAL) << "MXNet build without lapack. Function " << #func << " is not available."; \
return 1; \
}

int mxnet_lapack_##func(...);
MXNET_LAPACK_CWRAPPER1(spotrf, float)
MXNET_LAPACK_CWRAPPER1(dpotrf, double)
MXNET_LAPACK_CWRAPPER1(spotri, float)
Expand All @@ -375,7 +359,10 @@ inline void flip(int m, int n, DType *b, int ldb, DType *a, int lda) {

MXNET_LAPACK_CWRAPPER3(ssyevd, float)
MXNET_LAPACK_CWRAPPER3(dsyevd, double)

#undef MXNET_LAPACK_CWRAPPER1
#undef MXNET_LAPACK_CWRAPPER2
#undef MXNET_LAPACK_CWRAPPER3
#undef MXNET_LAPACK_UNAVAILABLE
#endif

template <typename DType>
Expand Down
3 changes: 3 additions & 0 deletions tests/cpp/operator/krprod_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,8 @@ TEST(row_wise_kronecker, FourInputMatrices) {
FreeSpace(&result);
}


#if MXNET_USE_LAPACK == 1
TEST(khatri_rao, OneInputMatrix) {
// Input matrices of shape (2, 4) which is also the expected result
DType mat[8] {1, 2, 3, 4, 5, 6, 7, 8};
Expand Down Expand Up @@ -444,5 +446,6 @@ TEST(inv_khatri_rao, ThreeInputMatricesTranposed) {
FreeSpace(&kr_t);
FreeSpace(&actual_dot);
}
#endif // MXNET_USE_LAPACK == 1
} // namespace op
} // namespace mxnet

0 comments on commit 6dae0bf

Please sign in to comment.