Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 97 additions & 4 deletions clients/benchmarks/client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,17 @@
#include "testing_nrm2_batched.hpp"
#include "testing_nrm2_strided_batched.hpp"
#include "testing_rot.hpp"
#include "testing_rot_batched.hpp"
#include "testing_rot_strided_batched.hpp"
#include "testing_rotg.hpp"
#include "testing_rotg_batched.hpp"
#include "testing_rotg_strided_batched.hpp"
#include "testing_rotm.hpp"
#include "testing_rotm_batched.hpp"
#include "testing_rotm_strided_batched.hpp"
#include "testing_rotmg.hpp"
#include "testing_rotmg_batched.hpp"
#include "testing_rotmg_strided_batched.hpp"
#include "testing_scal.hpp"
#include "testing_scal_batched.hpp"
#include "testing_scal_strided_batched.hpp"
Expand Down Expand Up @@ -174,14 +182,18 @@ struct perf_blas<
testing_set_get_vector<T>(arg);
else if(!strcmp(arg.function, "set_get_matrix"))
testing_set_get_matrix<T>(arg);
else if(!strcmp(arg.function, "rot"))
testing_rot<T>(arg);
else if(!strcmp(arg.function, "rotg"))
testing_rotg<T>(arg);
else if(!strcmp(arg.function, "rotm"))
testing_rotm<T>(arg);
else if(!strcmp(arg.function, "rotm_batched"))
testing_rotm_batched<T>(arg);
else if(!strcmp(arg.function, "rotm_strided_batched"))
testing_rotm_strided_batched<T>(arg);
else if(!strcmp(arg.function, "rotmg"))
testing_rotmg<T>(arg);
else if(!strcmp(arg.function, "rotmg_batched"))
testing_rotmg_batched<T>(arg);
else if(!strcmp(arg.function, "rotmg_strided_batched"))
testing_rotmg_strided_batched<T>(arg);
else if(!strcmp(arg.function, "gemv"))
testing_gemv<T>(arg);
else if(!strcmp(arg.function, "gemv_batched"))
Expand Down Expand Up @@ -326,6 +338,47 @@ struct perf_blas<T,
}
};

template <typename Ti, typename To = Ti, typename Tc = To, typename = void>
struct perf_blas_rot : rocblas_test_invalid
{
};

template <typename Ti, typename To, typename Tc>
struct perf_blas_rot<
Ti,
To,
Tc,
typename std::enable_if<(
(std::is_same<Ti, float>{} && std::is_same<Ti, To>{} && std::is_same<To, Tc>{})
|| (std::is_same<Ti, double>{} && std::is_same<Ti, To>{} && std::is_same<To, Tc>{})
|| (std::is_same<Ti, rocblas_float_complex>{} && std::is_same<To, float>{}
&& std::is_same<Tc, rocblas_float_complex>{})
|| (std::is_same<Ti, rocblas_float_complex>{} && std::is_same<To, float>{}
&& std::is_same<Tc, float>{})
|| (std::is_same<Ti, rocblas_double_complex>{} && std::is_same<To, double>{}
&& std::is_same<Tc, rocblas_double_complex>{})
|| (std::is_same<Ti, rocblas_double_complex>{} && std::is_same<To, double>{}
&& std::is_same<Tc, double>{}))>::type>
{
explicit operator bool()
{
return true;
}

void operator()(const Arguments& arg)
{
if(!strcmp(arg.function, "rot"))
testing_rot<Ti, To, Tc>(arg);
else if(!strcmp(arg.function, "rot_batched"))
testing_rot_batched<Ti, To, Tc>(arg);
else if(!strcmp(arg.function, "rot_strided_batched"))
testing_rot_strided_batched<Ti, To, Tc>(arg);
else
throw std::invalid_argument("Invalid combination --function "s + arg.function
+ " --a_type "s + rocblas_datatype2string(arg.a_type));
}
};

template <typename Ta, typename Tb = Ta, typename = void>
struct perf_blas_scal : rocblas_test_invalid
{
Expand Down Expand Up @@ -361,6 +414,40 @@ struct perf_blas_scal<
}
};

template <typename Ta, typename Tb = Ta, typename = void>
struct perf_blas_rotg : rocblas_test_invalid
{
};

template <typename Ta, typename Tb>
struct perf_blas_rotg<
Ta,
Tb,
typename std::enable_if<
(std::is_same<Ta, rocblas_double_complex>{} && std::is_same<Tb, double>{})
|| (std::is_same<Ta, rocblas_float_complex>{} && std::is_same<Tb, float>{})
|| (std::is_same<Ta, Tb>{} && std::is_same<Ta, float>{})
|| (std::is_same<Ta, Tb>{} && std::is_same<Ta, double>{})>::type>
{
explicit operator bool()
{
return true;
}
void operator()(const Arguments& arg)
{
if(!strcmp(arg.function, "rotg"))
testing_rotg<Ta, Tb>(arg);
else if(!strcmp(arg.function, "rotg_batched"))
testing_rotg_batched<Ta, Tb>(arg);
else if(!strcmp(arg.function, "rotg_strided_batched"))
testing_rotg_strided_batched<Ta, Tb>(arg);
else
throw std::invalid_argument("Invalid combination --function "s + arg.function
+ " --a_type " + rocblas_datatype2string(arg.a_type)
+ " --b_type " + rocblas_datatype2string(arg.b_type));
}
};

int run_bench_test(Arguments& arg)
{
// disable unit_check in client benchmark, it is only used in gtest unit test
Expand Down Expand Up @@ -523,6 +610,12 @@ int run_bench_test(Arguments& arg)
if(!strcmp(function, "scal") || !strcmp(function, "scal_batched")
|| !strcmp(function, "scal_strided_batched"))
rocblas_blas1_dispatch<perf_blas_scal>(arg);
else if(!strcmp(function, "rotg") || !strcmp(function, "rotg_batched")
|| !strcmp(function, "rotg_strided_batched"))
rocblas_blas1_dispatch<perf_blas_rotg>(arg);
else if(!strcmp(function, "rot") || !strcmp(function, "rot_batched")
|| !strcmp(function, "rot_strided_batched"))
rocblas_blas1_dispatch<perf_blas_rot>(arg);
else
rocblas_simple_dispatch<perf_blas>(arg);
}
Expand Down
26 changes: 25 additions & 1 deletion clients/common/rocblas_gentest.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,13 +199,17 @@ def setdefaults(test):
if test['function'] in ('asum_strided_batched', 'nrm2_strided_batched',
'scal_strided_batched', 'swap_strided_batched',
'copy_strided_batched', 'dot_strided_batched',
'dotc_strided_batched'):
'dotc_strided_batched', 'rot_strided_batched',
'rotm_strided_batched'):
if all([x in test for x in ('N', 'incx', 'stride_scale')]):
test.setdefault('stride_x', int(test['N'] * abs(test['incx']) *
test['stride_scale']))
if all([x in test for x in ('N', 'incy', 'stride_scale')]):
test.setdefault('stride_y', int(test['N'] * abs(test['incy']) *
test['stride_scale']))
# we are using stride_c for param in rotm; its default is stride_scale * 5
# Test the key directly: ('stride_scale') is a plain string, not a 1-tuple, so
# iterating over it checked every single character as a key of `test` and the
# default below was effectively never applied.
if 'stride_scale' in test:
    test.setdefault('stride_c', int(test['stride_scale']) * 5)

if test['function'] in ('ger_strided_batched'):
if all([x in test for x in ('M', 'incx', 'stride_scale')]):
Expand All @@ -215,6 +219,26 @@ def setdefaults(test):
test.setdefault('stride_y', int(test['N'] * abs(test['incy']) *
test['stride_scale']))

# we are using stride_c for arg c and stride_d for arg s in rotg
# these are single values for each batch
# Compare with ==: ('rotg_strided_batched') is a plain string, so `in` was a
# substring test that also matched names such as 'rotg' or 'batched'.
if test['function'] == 'rotg_strided_batched':
    if 'stride_scale' in test:
        test.setdefault('stride_a', int(test['stride_scale']))
        test.setdefault('stride_b', int(test['stride_scale']))
        test.setdefault('stride_c', int(test['stride_scale']))
        test.setdefault('stride_d', int(test['stride_scale']))
Comment thread
daineAMD marked this conversation as resolved.

# we are using stride_a for d1, stride_b for d2, and stride_c for param in
# rotmg. These are single values for each batch, except param which is
# a 5 element array
# Compare with ==: ('rotmg_strided_batched') is a plain string, so `in` was a
# substring test that also matched names such as 'rotmg' or 'batched'.
if test['function'] == 'rotmg_strided_batched':
    if 'stride_scale' in test:
        test.setdefault('stride_a', int(test['stride_scale']))
        test.setdefault('stride_b', int(test['stride_scale']))
        test.setdefault('stride_c', int(test['stride_scale']) * 5)
        test.setdefault('stride_x', int(test['stride_scale']))
        test.setdefault('stride_y', int(test['stride_scale']))

test.setdefault('stride_x', 0)
test.setdefault('stride_y', 0)

Expand Down
86 changes: 72 additions & 14 deletions clients/gtest/blas1_gtest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,17 @@
#include "testing_nrm2_batched.hpp"
#include "testing_nrm2_strided_batched.hpp"
#include "testing_rot.hpp"
#include "testing_rot_batched.hpp"
#include "testing_rot_strided_batched.hpp"
#include "testing_rotg.hpp"
#include "testing_rotg_batched.hpp"
#include "testing_rotg_strided_batched.hpp"
#include "testing_rotm.hpp"
#include "testing_rotm_batched.hpp"
#include "testing_rotm_strided_batched.hpp"
#include "testing_rotmg.hpp"
#include "testing_rotmg_batched.hpp"
#include "testing_rotmg_strided_batched.hpp"
#include "testing_scal.hpp"
#include "testing_scal_batched.hpp"
#include "testing_scal_strided_batched.hpp"
Expand Down Expand Up @@ -59,9 +67,17 @@ namespace
swap_batched,
swap_strided_batched,
rot,
rot_batched,
rot_strided_batched,
rotg,
rotg_batched,
rotg_strided_batched,
rotm,
rotm_batched,
rotm_strided_batched,
rotmg,
rotmg_batched,
rotmg_strided_batched,
};

// ----------------------------------------------------------------------------
Expand Down Expand Up @@ -93,32 +109,45 @@ namespace
{
bool is_scal = (BLAS1 == blas1::scal || BLAS1 == blas1::scal_batched
|| BLAS1 == blas1::scal_strided_batched);
bool is_rot = (BLAS1 == blas1::rot || BLAS1 == blas1::rot_batched
|| BLAS1 == blas1::rot_strided_batched);
bool is_rotg = (BLAS1 == blas1::rotg || BLAS1 == blas1::rotg_batched
|| BLAS1 == blas1::rotg_strided_batched);
bool is_rotmg = (BLAS1 == blas1::rotmg || BLAS1 == blas1::rotmg_batched
|| BLAS1 == blas1::rotmg_strided_batched);
bool is_batched = (BLAS1 == blas1::nrm2_batched || BLAS1 == blas1::asum_batched
|| BLAS1 == blas1::scal_batched || BLAS1 == blas1::swap_batched
|| BLAS1 == blas1::copy_batched || BLAS1 == blas1::dot_batched
|| BLAS1 == blas1::dotc_batched);
|| BLAS1 == blas1::dotc_batched || BLAS1 == blas1::rot_batched
|| BLAS1 == blas1::rotm_batched || BLAS1 == blas1::rotg_batched
|| BLAS1 == blas1::rotmg_batched);
bool is_strided
= (BLAS1 == blas1::nrm2_strided_batched || BLAS1 == blas1::asum_strided_batched
|| BLAS1 == blas1::scal_strided_batched
|| BLAS1 == blas1::swap_strided_batched
|| BLAS1 == blas1::copy_strided_batched
|| BLAS1 == blas1::dot_strided_batched
|| BLAS1 == blas1::dotc_strided_batched);
|| BLAS1 == blas1::dotc_strided_batched
|| BLAS1 == blas1::rot_strided_batched
|| BLAS1 == blas1::rotm_strided_batched
|| BLAS1 == blas1::rotg_strided_batched
|| BLAS1 == blas1::rotmg_strided_batched);

if((is_scal || BLAS1 == blas1::rot || BLAS1 == blas1::rotg)
&& arg.a_type != arg.b_type)
if((is_scal || is_rotg || is_rot) && arg.a_type != arg.b_type)
name << '_' << rocblas_datatype2string(arg.b_type);
if(BLAS1 == blas1::rot && arg.compute_type != arg.a_type)
if(is_rot && arg.compute_type != arg.a_type)
name << '_' << rocblas_datatype2string(arg.compute_type);

name << '_' << arg.N;
if(!is_rotg && !is_rotmg)
name << '_' << arg.N;

if(BLAS1 == blas1::axpy || is_scal)
name << '_' << arg.alpha << "_" << arg.alphai;

name << '_' << arg.incx;
if(!is_rotg && !is_rotmg)
name << '_' << arg.incx;

if(is_strided)
if(is_strided && !is_rotg)
{
name << '_' << arg.stride_x;
}
Expand All @@ -129,17 +158,31 @@ namespace
|| BLAS1 == blas1::dotc_batched || BLAS1 == blas1::dot_strided_batched
|| BLAS1 == blas1::dotc_strided_batched || BLAS1 == blas1::swap
|| BLAS1 == blas1::swap_batched || BLAS1 == blas1::swap_strided_batched
|| BLAS1 == blas1::rot || BLAS1 == blas1::rotm)
|| BLAS1 == blas1::rot || BLAS1 == blas1::rot_batched
|| BLAS1 == blas1::rot_strided_batched || BLAS1 == blas1::rotm
|| BLAS1 == blas1::rotm_batched || BLAS1 == blas1::rotm_strided_batched)
{
name << '_' << arg.incy;
}

if(BLAS1 == blas1::swap_strided_batched || BLAS1 == blas1::copy_strided_batched
|| BLAS1 == blas1::dot_strided_batched || BLAS1 == blas1::dotc_strided_batched)
|| BLAS1 == blas1::dot_strided_batched || BLAS1 == blas1::dotc_strided_batched
|| BLAS1 == blas1::rot_strided_batched || BLAS1 == blas1::rotm_strided_batched)
{
name << '_' << arg.stride_y;
}

if(BLAS1 == blas1::rotg_strided_batched)
{
name << '_' << arg.stride_a << '_' << arg.stride_b << '_' << arg.stride_c << '_'
<< arg.stride_d;
}

if(BLAS1 == blas1::rotm_strided_batched || BLAS1 == blas1::rotmg_strided_batched)
{
name << '_' << arg.stride_c;
}

if(is_batched || is_strided)
{
name << "_" << arg.batch_count;
Expand Down Expand Up @@ -220,7 +263,8 @@ namespace
|| std::is_same<Ti, rocblas_float_complex>{}
|| std::is_same<Ti, rocblas_double_complex>{}))

|| (BLAS1 == blas1::rot
|| ((BLAS1 == blas1::rot || BLAS1 == blas1::rot_batched
|| BLAS1 == blas1::rot_strided_batched)
&& ((std::is_same<Ti, float>{} && std::is_same<Ti, To>{} && std::is_same<To, Tc>{})
|| (std::is_same<Ti, double>{} && std::is_same<Ti, To>{}
&& std::is_same<To, Tc>{})
Expand All @@ -233,16 +277,22 @@ namespace
|| (std::is_same<Ti, rocblas_double_complex>{} && std::is_same<To, double>{}
&& std::is_same<Tc, double>{})))

|| (BLAS1 == blas1::rotg && std::is_same<To, Tc>{}
|| ((BLAS1 == blas1::rotg || BLAS1 == blas1::rotg_batched
|| BLAS1 == blas1::rotg_strided_batched)
&& std::is_same<To, Tc>{}
&& ((std::is_same<Ti, float>{} && std::is_same<Ti, To>{})
|| (std::is_same<Ti, double>{} && std::is_same<Ti, To>{})
|| (std::is_same<Ti, rocblas_float_complex>{} && std::is_same<To, float>{})
|| (std::is_same<Ti, rocblas_double_complex>{} && std::is_same<To, double>{})))

|| (BLAS1 == blas1::rotm && std::is_same<To, Ti>{} && std::is_same<To, Tc>{}
|| ((BLAS1 == blas1::rotm || BLAS1 == blas1::rotm_batched
|| BLAS1 == blas1::rotm_strided_batched)
&& std::is_same<To, Ti>{} && std::is_same<To, Tc>{}
&& (std::is_same<Ti, float>{} || std::is_same<Ti, double>{}))

|| (BLAS1 == blas1::rotmg && std::is_same<To, Ti>{} && std::is_same<To, Tc>{}
|| ((BLAS1 == blas1::rotmg || BLAS1 == blas1::rotmg_batched
|| BLAS1 == blas1::rotmg_strided_batched)
&& std::is_same<To, Ti>{} && std::is_same<To, Tc>{}
&& (std::is_same<Ti, float>{} || std::is_same<Ti, double>{}))>;

// Creates tests for one of the BLAS 1 functions
Expand Down Expand Up @@ -320,9 +370,17 @@ BLAS1_TESTING(swap, ARG1)
BLAS1_TESTING(swap_batched, ARG1)
BLAS1_TESTING(swap_strided_batched, ARG1)
BLAS1_TESTING(rot, ARG3)
BLAS1_TESTING(rot_batched, ARG3)
BLAS1_TESTING(rot_strided_batched, ARG3)
BLAS1_TESTING(rotg, ARG2)
BLAS1_TESTING(rotg_batched, ARG2)
BLAS1_TESTING(rotg_strided_batched, ARG2)
BLAS1_TESTING(rotm, ARG1)
BLAS1_TESTING(rotm_batched, ARG1)
BLAS1_TESTING(rotm_strided_batched, ARG1)
BLAS1_TESTING(rotmg, ARG1)
BLAS1_TESTING(rotmg_batched, ARG1)
BLAS1_TESTING(rotmg_strided_batched, ARG1)

// clang-format on

Expand Down
Loading