Merge branch 'zikelim/ExtensionFP16' into 'master'
Extend missing implementations for float16 support

See merge request walberla/walberla!643
modkin committed Dec 19, 2023
2 parents af126dc + 15f2030 commit 686862b
Showing 8 changed files with 262 additions and 31 deletions.
50 changes: 28 additions & 22 deletions CMakeLists.txt
@@ -1255,28 +1255,34 @@ endif()
## Half precision
##
############################################################################################################################
if (WALBERLA_BUILD_WITH_HALF_PRECISION_SUPPORT)
if (WALBERLA_CXX_COMPILER_IS_GNU OR WALBERLA_CXX_COMPILER_IS_CLANG)
message(STATUS "Configuring with *experimental* half precision (float16) support. You better know what you are doing.")
if (WALBERLA_CXX_COMPILER_IS_GNU AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 12.0.0)
message(WARNING "[WALBERLA_BUILD_WITH_HALF_PRECISION_SUPPORT] "
"Half precision support for gcc has only been tested with version >= 12. "
"You are using a previous version - it may not work correctly.")
endif ()
if (WALBERLA_CXX_COMPILER_IS_CLANG AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 15.0.0)
message(WARNING "[WALBERLA_BUILD_WITH_HALF_PRECISION_SUPPORT] "
"Half precision support for clang has only been tested with version >= 15. "
"You are using a previous version - it may not work correctly.")
endif ()
if (NOT WALBERLA_OPTIMIZE_FOR_LOCALHOST)
message(WARNING "[WALBERLA_BUILD_WITH_HALF_PRECISION_SUPPORT] "
"You are not optimizing for localhost. You may encounter linker errors, or WORSE: silent incorrect fp16 arithmetic! Consider also enabling WALBERLA_OPTIMIZE_FOR_LOCALHOST!")
endif ()
else ()
message(FATAL_ERROR "[WALBERLA_BUILD_WITH_HALF_PRECISION_SUPPORT] "
"Half precision support is currently only available for gcc and clang.")
endif ()
endif ()
if ( WALBERLA_BUILD_WITH_HALF_PRECISION_SUPPORT )
### Compiler requirements:
### Within this project, there are several checks to ensure that the template parameter 'ValueType'
### is a floating point number. The check is_floating_point<ValueType> is done primarily in our MPI implementation.
### For the IEEE 754 half-precision type _Float16, this check evaluates to true only if your compiler
### implements the C++23 proposal P1467R9 (Extended floating-point types and standard names).
### Compare:
### https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2022/p1467r9.html
###
### Right now (18.12.2023) this is the case only for gcc13.
### For more information see:
### https://gcc.gnu.org/projects/cxx-status.html#:~:text=Extended%20floating%2Dpoint%20types%20and%20standard%20names
### https://clang.llvm.org/cxx_status.html#:~:text=Extended%20floating%2Dpoint%20types%20and%20standard%20names

try_compile( WALBERLA_SUPPORT_HALF_PRECISION "${CMAKE_CURRENT_BINARY_DIR}" "${CMAKE_CURRENT_SOURCE_DIR}/cmake/TestFloat16.cpp"
CXX_STANDARD 23 OUTPUT_VARIABLE TRY_COMPILE_OUTPUT )
## message( STATUS ${TRY_COMPILE_OUTPUT} )
if ( NOT WALBERLA_SUPPORT_HALF_PRECISION )
message( FATAL_ERROR "Compiler: ${CMAKE_CXX_COMPILER} Version: ${CMAKE_CXX_COMPILER_VERSION} does not support half precision" )
endif ()

# Local host optimization
if ( NOT WALBERLA_OPTIMIZE_FOR_LOCALHOST )
message( WARNING "[WALBERLA_BUILD_WITH_HALF_PRECISION_SUPPORT] "
"You are not optimizing for localhost. You may encounter linker errors, or WORSE: silent incorrect fp16 arithmetic! Consider also enabling WALBERLA_OPTIMIZE_FOR_LOCALHOST!" )
endif () # Local host check

endif () # Check if WALBERLA_BUILD_WITH_HALF_PRECISION_SUPPORT is set

############################################################################################################################
# Documentation Generation
7 changes: 7 additions & 0 deletions cmake/TestFloat16.cpp
@@ -0,0 +1,7 @@
#include <type_traits>


int main()
{
static_assert(std::is_floating_point_v<_Float16>);
}
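For reference, a slightly more verbose probe than cmake/TestFloat16.cpp (hypothetical, not part of this commit) reports which of the relevant traits hold instead of failing the build. It assumes the compiler provides _Float16 at all; whether the traits are true depends on its C++23 / P1467R9 support:

// Hypothetical diagnostic probe, not part of this commit. It assumes the compiler provides
// _Float16 as a type at all; whether the traits below are true depends on C++23 / P1467R9 support.
#include <iostream>
#include <type_traits>

int main()
{
   std::cout << std::boolalpha;
   std::cout << "sizeof(_Float16)            = " << sizeof(_Float16) << " bytes\n";
   std::cout << "is_floating_point<_Float16> = " << std::is_floating_point_v< _Float16 > << '\n';
   std::cout << "is_arithmetic<_Float16>     = " << std::is_arithmetic_v< _Float16 > << '\n';
   return 0;
}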
9 changes: 9 additions & 0 deletions src/core/CMakeLists.txt
@@ -29,6 +29,15 @@ add_library( core )
if( MPI_FOUND )
target_link_libraries( core PUBLIC MPI::MPI_CXX )
endif()

if ( WALBERLA_BUILD_WITH_HALF_PRECISION_SUPPORT )
# Actual support for float16 is available only since C++23.
# Before that, is_arithmetic and is_floating_point evaluate to false for float16,
# and many STL functions are compatible with float16 only since C++23.
# Which features are actually supported depends on the compiler.
target_compile_features(core PUBLIC cxx_std_23)
endif ()

target_link_libraries( core PUBLIC ${SERVICE_LIBS} )
target_sources( core
PRIVATE
5 changes: 5 additions & 0 deletions src/core/DataTypes.cpp
@@ -26,6 +26,11 @@ namespace walberla {

namespace real_comparison
{
#ifdef WALBERLA_BUILD_WITH_HALF_PRECISION_SUPPORT
// const bfloat16 Epsilon< bfloat16 >::value = static_cast< bfloat16 >(1e-2); // machine eps is 2^-7
const float16 Epsilon< float16 >::value = static_cast< float16 >(1e-3); // machine eps is 2^-10
// Note: depending on the kind of 16-bit float (bfloat16 vs. float16), a different Epsilon must be used.
#endif
const float Epsilon< float >::value = static_cast< float >(1e-4);
const double Epsilon< double >::value = static_cast< double >(1e-8);
const long double Epsilon< long double >::value = static_cast< long double >(1e-10);
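The value 1e-3 chosen for Epsilon< float16 > sits just above the binary16 machine epsilon of 2^-10 ≈ 9.77e-4 mentioned in the comment. A minimal standalone sketch (hypothetical, not part of this commit) makes the 2^-10 granularity around 1 visible, assuming a compiler that provides _Float16:

// Hypothetical sketch, not part of this commit: binary16 stores a 10-bit mantissa, so the
// machine epsilon is 2^-10 = 0.0009765625, which is why 1e-3 is a sensible comparison epsilon.
#include <cassert>

int main()
{
   // 1 + 2^-10 is the next value after 1 that binary16 can represent ...
   const _Float16 next = static_cast< _Float16 >( 1.0f + 0.0009765625f );
   // ... whereas 1 + 2^-11 rounds back to exactly 1 (round to nearest, ties to even).
   const _Float16 same = static_cast< _Float16 >( 1.0f + 0.00048828125f );

   assert( static_cast< float >( next ) != 1.0f );
   assert( static_cast< float >( same ) == 1.0f );
   return 0;
}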
42 changes: 33 additions & 9 deletions src/core/DataTypes.h
@@ -175,21 +175,33 @@ using real_t = float;
/// Only bandwidth bound code may therefore benefit. None of this is guaranteed, and may change in the future.
///
#ifdef WALBERLA_BUILD_WITH_HALF_PRECISION_SUPPORT
# if defined(WALBERLA_CXX_COMPILER_IS_CLANG) || defined(WALBERLA_CXX_COMPILER_IS_GNU)
/// Clang version must be 15 or higher for x86 half precision support.
/// GCC version must be 12 or higher for x86 half precision support.
/// Also support seems to require SSE, so ensure that respective instruction sets are enabled.
/// FIXME: (not really right) Clang version must be 15 or higher for x86 half precision support.
/// FIXME: (not really right) GCC version must be 12 or higher for x86 half precision support.
/// FIXME: (I don't know) Also support seems to require SSE, so ensure that respective instruction sets are enabled.
/// See
/// https://clang.llvm.org/docs/LanguageExtensions.html#half-precision-floating-point
/// https://gcc.gnu.org/onlinedocs/gcc/Half-Precision.html
/// for more information.
/// ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
/// Compiler requirements:
/// Within this project, there are several checks to ensure that the template parameter 'ValueType'
/// is a floating point number. The check is_floating_point<ValueType> is done primarily in our MPI implementation.
/// For the IEEE 754 half-precision type _Float16, this check evaluates to true only if your compiler
/// implements the C++23 proposal P1467R9 (Extended floating-point types and standard names).
/// Compare:
/// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2022/p1467r9.html
///
/// Right now (18.12.2023) this is the case only for gcc13.
/// For more information see:
/// https://gcc.gnu.org/projects/cxx-status.html#:~:text=Extended%20floating%2Dpoint%20types%20and%20standard%20names
/// https://clang.llvm.org/cxx_status.html#:~:text=Extended%20floating%2Dpoint%20types%20and%20standard%20names

using half = _Float16;
// Note: there are two possible float16 formats.
// The one used right now is the IEE 754 float16 standard, consisting of a 5 bit exponent and a 10 bit mantissa.
// Another possible half precision format would be the one from Google Brain (bfloat16) with an 8 bit exponent and a 7 bit mantissa.
// Compare https://i10git.cs.fau.de/ab04unyc/walberla/-/issues/23
using float16 = half;
# else
static_assert(false, "\n\n### Attempting to built walberla with half precision support.\n"
"### However, the compiler you chose is not suited for that, or we simply have not implemented "
"support for half precision and your compiler.\n");
# endif
#endif
using float32 = float;
using float64 = double;
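Regarding the note above on the two possible 16-bit formats: IEEE 754 binary16 splits its bits into 1 sign, 5 exponent and 10 mantissa bits, whereas bfloat16 uses 1 sign, 8 exponent and 7 mantissa bits, trading precision for range. A small standalone sketch (hypothetical, not part of this commit) shows one practical consequence of the 10-bit mantissa, namely that consecutive integers are exact only up to 2^11 = 2048:

// Hypothetical illustration, not part of this commit: with 10 stored mantissa bits plus the
// implicit leading bit, binary16 represents consecutive integers only up to 2^11 = 2048;
// above that, the spacing between representable values grows to 2.
#include <cassert>

int main()
{
   const _Float16 a = static_cast< _Float16 >( 2048.0f );
   const _Float16 b = static_cast< _Float16 >( 2049.0f ); // not representable, rounds to 2048
   const _Float16 c = static_cast< _Float16 >( 2050.0f ); // representable again

   assert( static_cast< float >( a ) == static_cast< float >( b ) );
   assert( static_cast< float >( c ) == 2050.0f );
   return 0;
}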
@@ -228,6 +240,10 @@ inline bool realIsIdentical( const real_t a, const real_t b )
namespace real_comparison
{
template< class T > struct Epsilon;
#ifdef WALBERLA_BUILD_WITH_HALF_PRECISION_SUPPORT
using walberla::float16;
template<> struct Epsilon< float16 > { static const float16 value; };
#endif
template<> struct Epsilon< float > { static const float value; };
template<> struct Epsilon< double > { static const double value; };
template<> struct Epsilon< long double > { static const long double value; };
@@ -254,6 +270,14 @@ inline bool floatIsEqual( float lhs, float rhs, const float epsilon = real_compa
return std::fabs( lhs - rhs ) < epsilon;
}

#ifdef WALBERLA_BUILD_WITH_HALF_PRECISION_SUPPORT
inline bool floatIsEqual( walberla::float16 lhs, walberla::float16 rhs, const walberla::float16 epsilon = real_comparison::Epsilon<walberla::float16>::value )
{
const auto difference = lhs - rhs;
return ( (difference < 0) ? -difference : difference ) < epsilon;
}
#endif

} // namespace walberla

#define WALBERLA_UNUSED(x) (void)(x)
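A minimal usage sketch for the new float16 overload of floatIsEqual (hypothetical, not part of this commit). It assumes a waLBerla build configured with WALBERLA_BUILD_WITH_HALF_PRECISION_SUPPORT and linked against the core module, so that walberla::float16 and real_comparison::Epsilon< float16 > are available; the values 0.1 and 0.1004 are arbitrary examples whose difference stays below the default epsilon of 1e-3:

// Hypothetical usage sketch, not part of this commit. Requires a build with
// WALBERLA_BUILD_WITH_HALF_PRECISION_SUPPORT and linking against the core module.
#include "core/DataTypes.h"

int main()
{
   using walberla::float16;

   const float16 a = static_cast< float16 >( 0.1 );
   const float16 b = static_cast< float16 >( 0.1004 ); // differs from a by less than 1e-3

   // Uses the default epsilon real_comparison::Epsilon< float16 >::value defined in DataTypes.cpp.
   return walberla::floatIsEqual( a, b ) ? 0 : 1;
}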
3 changes: 3 additions & 0 deletions src/core/mpi/MPIWrapper.h
@@ -353,6 +353,9 @@ WALBERLA_CREATE_MPITRAIT_SPECIALIZATION( unsigned short int , MPI_UNSIGNED_SHORT
WALBERLA_CREATE_MPITRAIT_SPECIALIZATION( unsigned int , MPI_UNSIGNED );
WALBERLA_CREATE_MPITRAIT_SPECIALIZATION( unsigned long int , MPI_UNSIGNED_LONG );
WALBERLA_CREATE_MPITRAIT_SPECIALIZATION( unsigned long long , MPI_UNSIGNED_LONG_LONG );
#ifdef WALBERLA_BUILD_WITH_HALF_PRECISION_SUPPORT
WALBERLA_CREATE_MPITRAIT_SPECIALIZATION( walberla::float16 , MPI_WCHAR );
#endif
WALBERLA_CREATE_MPITRAIT_SPECIALIZATION( float , MPI_FLOAT );
WALBERLA_CREATE_MPITRAIT_SPECIALIZATION( double , MPI_DOUBLE );
WALBERLA_CREATE_MPITRAIT_SPECIALIZATION( long double , MPI_LONG_DOUBLE );
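The trait above maps walberla::float16 to MPI_WCHAR, presumably because the MPI standard defines no native 16-bit floating-point datatype and only a type of matching transfer size is needed; note that MPI_WCHAR corresponds to wchar_t, whose size is implementation-defined. A purely hypothetical alternative sketch (not part of this commit; makeFloat16Datatype is an invented name) registers a dedicated 2-byte datatype instead. It covers point-to-point transfers, but built-in reductions such as MPI_SUM would still need a custom MPI_Op:

// Purely hypothetical alternative, not part of this commit: register a dedicated 2-byte MPI
// datatype for float16 instead of reusing a built-in type of (hopefully) matching size.
// Sufficient for point-to-point transfers; built-in reductions such as MPI_SUM would still
// require a custom MPI_Op that understands the float16 bit pattern.
#include <mpi.h>

MPI_Datatype makeFloat16Datatype()
{
   MPI_Datatype halfType = MPI_DATATYPE_NULL;
   MPI_Type_contiguous( 2, MPI_BYTE, &halfType ); // a float16 value occupies exactly two bytes
   MPI_Type_commit( &halfType );
   return halfType; // release with MPI_Type_free() before MPI_Finalize()
}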
14 changes: 14 additions & 0 deletions tests/core/CMakeLists.txt
@@ -222,3 +222,17 @@ if( WALBERLA_BUILD_WITH_PARMETIS )
waLBerla_compile_test( NAME PlainParMetisTest FILES load_balancing/PlainParMetisTest.cpp )
waLBerla_execute_test( NAME PlainParMetisTest PROCESSES 3 )
endif()

###################
# Mixed Precision #
###################

if ( WALBERLA_BUILD_WITH_HALF_PRECISION_SUPPORT )
waLBerla_compile_test( NAME Float16SupportTest FILES Float16SupportTest.cpp DEPENDS core )
# Actual support for float16 is available only since C++23.
# Before that, is_arithmetic and is_floating_point evaluate to false for float16,
# and many STL functions are compatible with float16 only since C++23.
# Which features are actually supported depends on the compiler.
target_compile_features( Float16SupportTest PUBLIC cxx_std_23 )
waLBerla_execute_test(NAME Float16SupportTest)
endif ()
163 changes: 163 additions & 0 deletions tests/core/Float16SupportTest.cpp
@@ -0,0 +1,163 @@
//======================================================================================================================
//
// This file is part of waLBerla. waLBerla is free software: you can
// redistribute it and/or modify it under the terms of the GNU General Public
// License as published by the Free Software Foundation, either version 3 of
// the License, or (at your option) any later version.
//
// waLBerla is distributed in the hope that it will be useful, but WITHOUT
// ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
// FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
// for more details.
//
// You should have received a copy of the GNU General Public License along
// with waLBerla (see COPYING.txt). If not, see <http://www.gnu.org/licenses/>.
//
//! \file Float16SupportTest.cpp
//! \ingroup core
//! \author Michael Zikeli <[email protected]>
//
//======================================================================================================================

#include <memory>
#include <numeric>

#include "core/DataTypes.h"
#include "core/Environment.h"
#include "core/logging/Logging.h"

namespace walberla::simple_Float16_test {
using walberla::floatIsEqual;
using walberla::real_t;
using walberla::uint_c;
using walberla::uint_t;

// === Choosing Accuracy ===
//+++ Precision : fp16 +++
using walberla::float16;
using walberla::float32;
using walberla::float64;
using dst_t = float16;
using src_t = real_t;
constexpr real_t precisionLimit = walberla::float16( 1e-3 );
const std::string precisionType = "float16";
constexpr const auto maxLevel = uint_t( 3 );

void simple_array_test()
{
auto fpSrc = std::make_shared< src_t[] >( 10 );
auto fpDst = std::make_shared< dst_t[] >( 10 );

std::fill_n( fpSrc.get(), 10, 17. );
std::fill_n( fpDst.get(), 10, (dst_t) 17. );

fpSrc[5] = 8.;
fpDst[5] = (dst_t) 8.;

// Test equality with custom compare
WALBERLA_CHECK_LESS( std::fabs( fpSrc[9] - (src_t) fpDst[9] ), precisionLimit );
WALBERLA_CHECK_LESS( std::fabs( fpSrc[5] - (src_t) fpDst[5] ), precisionLimit );
// Test specialized floatIsEqual
WALBERLA_CHECK( floatIsEqual( fpSrc[9], (src_t) fpDst[9], (src_t) precisionLimit ) );
WALBERLA_CHECK( floatIsEqual( (dst_t) fpSrc[9], fpDst[9], (dst_t) precisionLimit ) );
WALBERLA_CHECK_FLOAT_EQUAL( (dst_t) fpSrc[9], fpDst[9] );

// Test std::fill_n
auto other_fpDst = std::make_shared< dst_t[] >( 10 );
std::fill_n( other_fpDst.get(), 10, (dst_t) 2. );
WALBERLA_CHECK_FLOAT_EQUAL( (dst_t) 2., other_fpDst[9] );
WALBERLA_CHECK_FLOAT_EQUAL( (dst_t) 2., other_fpDst[5] );

// Test std::swap
std::swap( fpDst, other_fpDst );
fpDst[5] = (dst_t) 9.;

WALBERLA_CHECK_FLOAT_EQUAL( (dst_t) fpSrc[9], other_fpDst[9] );
WALBERLA_CHECK_FLOAT_EQUAL( (dst_t) fpSrc[5], other_fpDst[5] );
WALBERLA_CHECK_FLOAT_EQUAL( (dst_t) 2., fpDst[9] );
WALBERLA_CHECK_FLOAT_EQUAL( (dst_t) 9., fpDst[5] );

} // simple_Float16_test::simple_array_test()

void vector_test()
{
auto fpSrc = std::vector< src_t >( 10 );
auto fpDst_cast = std::vector< dst_t >( 10 );
auto fp32 = std::vector< walberla::float32 >( 10 );
auto fpDst = std::vector< dst_t >( 10 );

fpSrc.assign( 10, 1.5 );
fpDst_cast.assign( 10, (dst_t) 1.5 );
fp32.assign( 10, 1.5f );
std::copy( fpSrc.begin(), fpSrc.end(), fpDst.begin() );
WALBERLA_LOG_WARNING_ON_ROOT(
" std::vector.assign is not able to assign "
<< typeid( src_t ).name() << " values to container of type " << precisionType << ".\n"
<< " Therefore, the floating point value for assign must be cast beforehand or std::copy must be used, since it uses a static_cast internally." );

fpSrc[5] = 2.3;
fpDst_cast[5] = (dst_t) 2.3;
fp32[5] = 2.3f;
fpDst[5] = (dst_t) 2.3;

WALBERLA_CHECK_FLOAT_EQUAL( (walberla::float32) fpSrc[0], fp32[0] );
WALBERLA_CHECK_FLOAT_EQUAL( (walberla::float32) fpSrc[9], fp32[9] );
WALBERLA_CHECK_FLOAT_EQUAL( (walberla::float32) fpSrc[5], fp32[5] );
WALBERLA_CHECK_FLOAT_EQUAL( (dst_t) fpSrc[0], fpDst_cast[0] );
WALBERLA_CHECK_FLOAT_EQUAL( (dst_t) fpSrc[9], fpDst_cast[9] );
WALBERLA_CHECK_FLOAT_EQUAL( (dst_t) fpSrc[5], fpDst_cast[5] );
WALBERLA_CHECK_FLOAT_EQUAL( (dst_t) fpSrc[0], fpDst[0] );
WALBERLA_CHECK_FLOAT_EQUAL( (dst_t) fpSrc[9], fpDst[9] );
WALBERLA_CHECK_FLOAT_EQUAL( (dst_t) fpSrc[5], fpDst[5] );
WALBERLA_CHECK_EQUAL( typeid( fpDst ), typeid( fpDst_cast ) );

// Add up all elements of the vector to check whether the result is sufficiently correct.
{
const auto sumSrc = std::reduce(fpSrc.begin(), fpSrc.end());
const auto sumDst = std::reduce(fpDst.begin(), fpDst.end());
WALBERLA_CHECK_FLOAT_EQUAL( (dst_t)sumSrc, sumDst );
}
{
fpSrc.assign( 13, 1.3 );
fpDst.resize( fpSrc.size() ); // the destination must hold all 13 elements before copying
std::copy( fpSrc.begin(), fpSrc.end(), fpDst.begin() );
const auto sumSrc = std::reduce(fpSrc.begin(), fpSrc.end());
const auto sumDst = std::reduce(fpDst.begin(), fpDst.end());
WALBERLA_CHECK_FLOAT_UNEQUAL( (dst_t)sumSrc, sumDst );
}
} // simple_Float16_test::vector_test()

int main( int argc, char** argv )
{
// This check only works since C++23 and is used in many implementations, so it is important that it works.
WALBERLA_CHECK( std::is_arithmetic< dst_t >::value );

walberla::Environment env( argc, argv );
walberla::logging::Logging::instance()->setLogLevel( walberla::logging::Logging::INFO );
walberla::MPIManager::instance()->useWorldComm();

WALBERLA_LOG_INFO_ON_ROOT( " This run is executed with " << precisionType );
WALBERLA_LOG_INFO_ON_ROOT( " machine precision limit is " << precisionLimit );
const std::string stringLine( 125, '=' );
WALBERLA_LOG_INFO_ON_ROOT( stringLine );

WALBERLA_LOG_INFO_ON_ROOT( " Start a test with shared_pointer<float16[]>." );
simple_array_test();

WALBERLA_LOG_INFO_ON_ROOT( " Start a test with std::vector<float16>." );
vector_test();

WALBERLA_LOG_INFO_ON_ROOT( " Start a where float32 is sufficient but float16 is not." );
WALBERLA_CHECK_FLOAT_UNEQUAL( dst_t(1.0)-dst_t(0.3), 1.0-0.3 );
WALBERLA_CHECK_FLOAT_EQUAL( 1.0f-0.3f, 1.0-0.3 );

return 0;
} // simple_Float16_test::main()

} // namespace walberla::simple_Float16_test

int main( int argc, char** argv )
{
walberla::simple_Float16_test::main( argc, argv );

return EXIT_SUCCESS;
} // main()
